/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef _WIN32 #include #endif #include #include #include #include "ir/anf.h" #include "pybind_api/ir/primitive_py.h" #include "ir/meta_func_graph.h" #include "ir/func_graph_cloner.h" #include "ir/manager.h" #include "pipeline/jit/resource.h" #include "pipeline/jit/parse/parse.h" #include "pipeline/jit/parse/resolve.h" #include "frontend/optimizer/ad/dfunctor.h" #include "frontend/operator/ops.h" #include "frontend/operator/composite/composite.h" #include "utils/utils.h" #include "utils/symbolic.h" #include "utils/primitive_utils.h" #include "utils/ms_context.h" #include "utils/info.h" #include "debug/trace.h" #include "debug/common.h" #include "debug/dump_proto.h" #include "mindspore/core/load_mindir/load_model.h" #include "utils/system/sha256.h" #include "utils/file_utils.h" namespace mindspore { namespace ad { KPrim g_k_prims; namespace { constexpr char kBpropMindIRSuffix[] = "_bprop.mindir"; constexpr char kBpropMindIRDir[] = "/../bprop_mindir/"; constexpr char serializable_bprop_ops[] = "serializable_bprop_ops"; constexpr char bprop_mindir_module[] = "mindspore.ops.bprop_mindir"; #ifndef _WIN32 std::string GetBpropDir() { static std::string bprop_dir; if (bprop_dir.empty()) { py::module mod = py::module::import("mindspore.ops._grad"); auto grad_file_path = mod.attr("__file__").cast(); bprop_dir = grad_file_path.substr(0, grad_file_path.find_last_of('/')); } return bprop_dir; } bool BpropMindirDirExists() { auto bprop_mindir_dir = GetBpropDir() + kBpropMindIRDir; DIR *dir = opendir(bprop_mindir_dir.c_str()); if (dir != nullptr) { if (closedir(dir) == -1) { MS_LOG(WARNING) << "The bprop mindir dir \"" << bprop_mindir_dir << "\" close failed!"; } return true; } MS_LOG(ERROR) << "Open bprop mindir dir \"" << bprop_mindir_dir << "\" failed." << ErrnoToString(errno); return false; } // Get the serializable bprop list from the module mindspore.ops.bprop_mindir in python. mindspore::HashSet GetSerializableBpropList() { mindspore::HashSet serializable_bprop_list; if (!BpropMindirDirExists()) { return serializable_bprop_list; } py::module mod = py::module::import(bprop_mindir_module); py::object serializable_bprop_ops_attr = mod.attr(serializable_bprop_ops); if (!py::isinstance(serializable_bprop_ops_attr)) { MS_LOG(WARNING) << "Can not get the the serializable bprop ops list from python, it is not a python list."; return serializable_bprop_list; } auto ops_list = serializable_bprop_ops_attr.cast(); for (size_t i = 0; i < ops_list.size(); ++i) { auto prim_adapter = ops_list[i].cast(); if (prim_adapter == nullptr) { MS_LOG(EXCEPTION) << "The python obj in serializable bprop list should be a Primitive, but it is " << py::str(ops_list[i]); } serializable_bprop_list.insert(prim_adapter->name()); } return serializable_bprop_list; } bool IsSerializableBprop(const std::string &prim_name) { static mindspore::HashSet serializable_bprop_list = GetSerializableBpropList(); return std::any_of(serializable_bprop_list.begin(), serializable_bprop_list.end(), [&prim_name](const std::string &serializable_bprop_prim_name) { return prim_name == serializable_bprop_prim_name; }); } void GetFilesHash(const std::string &dir, mindspore::HashMap *bprop_hash_to_file) { if (dir.empty()) { MS_LOG(ERROR) << "The directory path is empty."; return; } struct stat s {}; int ret = stat(dir.c_str(), &s); if (ret != 0) { MS_LOG(ERROR) << "stat dir \"" << dir << "\" failed, ret is : " << ret; return; } if (!S_ISDIR(s.st_mode)) { MS_LOG(ERROR) << "The path \"" << dir << "\" is not a directory."; return; } DIR *open_dir = opendir(dir.c_str()); if (open_dir == nullptr) { MS_LOG(ERROR) << "open dir " << dir.c_str() << " failed"; return; } struct dirent *filename; while ((filename = readdir(open_dir)) != nullptr) { std::string d_name = std::string(filename->d_name); if (d_name == "." || d_name == ".." || filename->d_type != DT_REG) { continue; } auto real_path = std::string(dir) + "/" + filename->d_name; (void)bprop_hash_to_file->emplace(system::sha256::GetHashFromFile(real_path), real_path); } closedir(open_dir); } mindspore::HashMap GetAllBpropFileHash() { mindspore::HashMap bprop_hash_to_file; auto bprop_dir = GetBpropDir(); auto realpath = FileUtils::GetRealPath(common::SafeCStr(bprop_dir)); if (!realpath.has_value()) { MS_LOG(EXCEPTION) << "Get real path of bprop dir failed. path=" << bprop_dir; } GetFilesHash(realpath.value(), &bprop_hash_to_file); return bprop_hash_to_file; } bool CheckBpropHash(const std::string &hash) { // Get every hash of all the bprop files. static auto bprop_hash_to_file = GetAllBpropFileHash(); if (bprop_hash_to_file.find(hash) != bprop_hash_to_file.end()) { return true; } std::string bprop_dir = GetBpropDir(); auto bprop_mindir_path = bprop_dir + kBpropMindIRDir; MS_LOG(ERROR) << "The bprop mindir files are not up to date. Please run the " << bprop_mindir_path << "generate_mindir.py to generate new mindir files.\n" << "bprop_fg hash: " << hash << "\n" << "bprop hash list: \n"; for (const auto &iter : bprop_hash_to_file) { MS_LOG(ERROR) << iter.first; } return false; } FuncGraphPtr ImportBpropFromMindIR(const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); std::string bprop_dir = GetBpropDir(); auto bprop_mindir_path = bprop_dir + kBpropMindIRDir; std::optional bprop_mindir_realpath = FileUtils::GetRealPath(common::SafeCStr(bprop_mindir_path + prim->name() + kBpropMindIRSuffix)); bool bprop_cache_file_exists = bprop_mindir_realpath.has_value() && Common::FileExists(bprop_mindir_realpath.value()); if (!bprop_cache_file_exists) { return nullptr; } MindIRLoader mindir_loader; auto bprop_fg = mindir_loader.LoadMindIR(bprop_mindir_realpath.value()); if (!CheckBpropHash(bprop_fg->bprop_hash())) { MS_LOG(EXCEPTION) << "The bprop mindir files are not up to date."; } return bprop_fg; } void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(prim); std::string bprop_dir = GetBpropDir(); auto bprop_mindir_path = bprop_dir + kBpropMindIRDir; std::optional bprop_mindir_realpath = Common::CreatePrefixPath(bprop_mindir_path + prim->name() + kBpropMindIRSuffix, true); if (!bprop_mindir_realpath.has_value()) { MS_LOG(ERROR) << "Failed to get the realpath of bprop mindir: " << bprop_mindir_path << prim->name() << kBpropMindIRSuffix; return; } std::ofstream fout(bprop_mindir_realpath.value()); if (!fout.is_open()) { MS_LOG(ERROR) << "Open cache file '" << bprop_mindir_realpath.value() << "' failed!" << ErrnoToString(errno); return; } ModelProtoPtr fg_model = GetBinaryProto(func_graph); if (fg_model == nullptr) { MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed."; fout.close(); return; } if (!fg_model->SerializeToOstream(&fout)) { MS_LOG(ERROR) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \"" << bprop_mindir_realpath.value() << "\"."; fout.close(); return; } fout.close(); ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR); } AnfNodePtr GetPythonOps(const FuncGraphPtr &fg, const AnfNodePtr &origin_node, const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(origin_node); MS_EXCEPTION_IF_NULL(prim); // DoSignaturePrimitive to the pair of primitive name and module name. static mindspore::HashMap> python_ops{ {"S-Prim-zeros_like_leaf", {"zeros_like", ""}}, {"S-Prim-getitem", {"getitem", "mindspore.ops.composite.multitype_ops.getitem_impl"}}}; auto iter = python_ops.find(prim->name()); if (iter == python_ops.end()) { return nullptr; } ValuePtr python_ops_value; if (!iter->second.second.empty()) { python_ops_value = prim::GetPythonOps(iter->second.first, iter->second.second); } else { python_ops_value = prim::GetPythonOps(iter->second.first); } auto origin_cnode = origin_node->cast(); MS_EXCEPTION_IF_NULL(origin_cnode); auto &origin_inputs = origin_cnode->inputs(); std::vector new_inputs{NewValueNode(python_ops_value)}; (void)std::copy(origin_inputs.begin() + 1, origin_inputs.end(), std::back_inserter(new_inputs)); return fg->NewCNode(new_inputs); } // Replace the nodes whose python obj of primitive is needed in the renormalize process, // with the new created python ops, such as zeros_like. void ReplacePythonOps(const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(fg); std::vector all_nodes = DeepScopedGraphSearch(fg->get_return()); for (const auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; } auto cnode = node->cast(); for (size_t i = 0; i < cnode->size(); ++i) { auto prim = GetCNodePrimitive(cnode->input(i)); if (prim == nullptr) { continue; } auto new_input = GetPythonOps(fg, cnode->input(i), prim); if (new_input == nullptr) { continue; } cnode->set_input(i, new_input); } } } std::string GetBpropFileHash(const py::function &fn) { static auto bprop_hash_to_file = GetAllBpropFileHash(); // Get the file where the bprop function is defined. auto filename = fn.attr("__code__").attr("co_filename").cast(); // Get the hash of the file. auto it = std::find_if(bprop_hash_to_file.begin(), bprop_hash_to_file.end(), [&filename](const auto &item) { return item.second == filename; }); if (it != bprop_hash_to_file.end()) { return it->first; } return ""; } #endif } // namespace #ifndef _WIN32 // Given a python primitive, export a mindir file from the bprop defined in python. void KPrim::ExportBpropMindir(const py::object &obj) { auto prim_adapter = obj.cast(); if (prim_adapter == nullptr) { MS_LOG(EXCEPTION) << "The python obj to be exported to bprop mindir should be a Primitive, but it is " << py::str(obj); } auto prim = prim_adapter->attached_primitive(); if (prim == nullptr) { prim = std::make_shared(obj, prim_adapter); prim_adapter->set_attached_primitive(prim); } // Get the bprop function from python. py::function fn = prim->cast()->GetBpropFunction(); if (py::isinstance(fn)) { fn = GetBpropFunction(prim->name()); } if (!fn || py::isinstance(fn)) { MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; } std::string bprop_hash = GetBpropFileHash(fn); if (bprop_hash.empty()) { MS_LOG(EXCEPTION) << "Fail to get the file hash for " << prim->name(); } // Parse and resolve. auto func_graph = parse::ParsePythonCode(fn); if (func_graph == nullptr) { MS_LOG(EXCEPTION) << "Fail to parse bprop function for " << prim->name() << "."; } auto res = std::make_shared(); (void)parse::ResolveFuncGraph(func_graph, res); func_graph->set_bprop_hash(bprop_hash); ExportBpropToMindIR(prim, func_graph); } #endif FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) { // Set a child scope named "grad'PrimitiveName'" for the bprop function, // and add "Gradients" to the front. static const std::string gradients_scope = "Gradients/"; static const std::string grad_op_child_scope_prefix = "/grad"; MS_EXCEPTION_IF_NULL(prim); auto scope = std::make_shared(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + grad_op_child_scope_prefix + prim->name()); ScopeGuard scope_guard(scope); // Firstly we get bprop from mindir. If failed, parse the python function registered. FuncGraphPtr func_graph = nullptr; #ifndef _WIN32 bool serializable = IsSerializableBprop(prim->name()); if (serializable) { func_graph = ImportBpropFromMindIR(prim); if (func_graph != nullptr) { ReplacePythonOps(func_graph); if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP)) { func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true); } return func_graph; } } #endif py::function fn; if (prim->is_base()) { fn = GetBpropFunction(prim->name()); } else { fn = prim->cast()->GetBpropFunction(); if (py::isinstance(fn)) { fn = GetBpropFunction(prim->name()); } } if (!fn || py::isinstance(fn)) { MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn); return nullptr; } func_graph = parse::ParsePythonCode(fn); if (func_graph == nullptr) { MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << "."; return nullptr; } auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP); if (bprop_flag) { func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true); } pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared(); (void)parse::ResolveFuncGraph(func_graph, res); #ifndef _WIN32 // Check whether the bprop needs to be exported. if (serializable) { std::string bprop_hash = GetBpropFileHash(fn); if (!bprop_hash.empty()) { func_graph->set_bprop_hash(bprop_hash); ExportBpropToMindIR(prim, func_graph); } } #endif return func_graph; } FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) { FuncGraphPtr bprop_fg = nullptr; auto iter = bprop_registry_.find(prim); if (iter != bprop_registry_.end()) { bprop_fg = iter->second; } if (bprop_fg == nullptr) { bprop_fg = GetBprop(prim); if (bprop_fg != nullptr) { // Set bprop_g graph cache bprop_registry_[prim] = bprop_fg; } } return bprop_fg; } FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) { static const std::string ad_module = "mindspore.ops._grad.grad_implementations"; std::string func_name = "_fprop_" + prim->name(); py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name); auto func_graph = parse::ParsePythonCode(fn); MS_EXCEPTION_IF_NULL(func_graph); return BasicClone(func_graph); } MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); auto iter = bprop_registry_meta_.find(prim); if (iter != bprop_registry_meta_.end()) { return iter->second; } if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { MetaFuncGraphPtr meta = std::make_shared("make_tuple_gradient"); bprop_registry_meta_[prim::kPrimMakeTuple] = meta; return meta; } if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) { MetaFuncGraphPtr meta = std::make_shared("make_list_gradient"); bprop_registry_meta_[prim::kPrimMakeList] = meta; return meta; } MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; } static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &monad) { const auto &output = bprop_fg->output(); MS_EXCEPTION_IF_NULL(output); auto output_cnode = output->cast(); if (output_cnode != nullptr) { // If output_cnode has the form like (make_tuple, x, y). output_cnode->add_input(monad); return; } // If output is an empty tuple, create a (make_tuple, monad) as the new output. auto make_tuple = NewValueNode(prim::kPrimMakeTuple); output_cnode = bprop_fg->NewCNode({make_tuple, monad}); bprop_fg->set_output(output_cnode); } // Append U or/and IO monad to output of Bprop funcgraph. static void AdjustForAutoMonad(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) { auto effect_info = GetPrimEffectInfo(prim); if (effect_info.memory) { MS_LOG(DEBUG) << "Append U monad for Bprop FuncGraph of Primitive " << prim->ToString(); auto u = NewValueNode(kUMonad); u->set_abstract(kUMonad->ToAbstract()); AppendMonadOutput(bprop_fg, u); } if (effect_info.io) { MS_LOG(DEBUG) << "Append IO monad for Bprop FuncGraph of Primitive " << prim->ToString(); auto io = NewValueNode(kIOMonad); io->set_abstract(kIOMonad->ToAbstract()); AppendMonadOutput(bprop_fg, io); } } std::vector GeneratePrimalDebugInfo(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { std::vector primal_debug_infos; if (resources != nullptr) { auto manager = resources->manager(); auto &users = manager->node_users()[value_node]; for (auto user_iter = users.begin(); user_iter != users.end(); ++user_iter) { primal_debug_infos.push_back(user_iter->first->debug_info()); } } return primal_debug_infos; } FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { if (!IsValueNode(value_node)) { MS_LOG(EXCEPTION) << "Primitive node is not valid."; } auto prim = GetValueNode(value_node); if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) { auto fprop = GetFprop(prim); fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer)); return fprop; } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { return nullptr; } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) { return nullptr; } FuncGraphPtr bprop_fg = nullptr; if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { if (MsContext::GetInstance()->get_param(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) { MS_LOG(EXCEPTION) << "The Primitive 'HookBackward' is not supported in graph mode, which is only supported in pynative mode.\n" << trace::GetDebugInfo(cnode->debug_info()); } bprop_fg = BpropCut(value_node, resources); } else { auto iter = bprop_registry_.find(prim); if (iter != bprop_registry_.end()) { bprop_fg = iter->second; } if (bprop_fg == nullptr) { bprop_fg = GetBprop(prim, resources); if (bprop_fg != nullptr) { // Set bprop_g graph cache bprop_registry_[prim] = bprop_fg; } else { bprop_fg = FakeBprop(value_node, resources); } } } AdjustForAutoMonad(prim, bprop_fg); mindspore::HashMap primal_attrs; std::vector primal_debug_infos = GeneratePrimalDebugInfo(value_node, resources); if (cnode != nullptr) { primal_attrs = cnode->primal_attrs(); const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId(); primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr); } auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos); if (expanded_fg == nullptr) { MS_LOG(EXCEPTION) << "Failed convert " << prim->name() << " prim bprop function to J expanded func graph. NodeInfo: " << trace::GetDebugInfo(bprop_fg->debug_info()); } if (lift_fv_before_grad && IsPrimitiveEquals(prim, prim::kPrimSwitch)) { // Inline fprop_switch before renormalize; expanded_fg->set_flag(FUNC_GRAPH_FLAG_FORCE_INLINE, true); MS_LOG(DEBUG) << "set force_inline for fg: " << expanded_fg->ToString(); } return expanded_fg; } AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg) { // current_primal_fg may have extra parameters like u_monad, io_monad std::vector extra_args; // caller had checked size() - 2 is greater than 0. auto bprop_fg_param_size = bprop_fg->parameters().size() - 2; if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) { auto current_primal_fg_param_size = current_primal_fg->parameters().size(); MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so " "Insert it. Extra parameters size: " << current_primal_fg_param_size - bprop_fg_param_size; for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) { const auto &primal_node = current_primal_fg->parameters()[i]; AnfNodePtr extra_node; // Simplify zeros_like(primal_node) to U or IO, so extra_node in bprop_fg will not refer to primal_node // as a free variable of primal_graph. // Notes: if the implementation of zeros_like changes, here too. if (HasAbstractUMonad(primal_node)) { extra_node = NewValueNode(kUMonad); } else if (HasAbstractIOMonad(primal_node)) { extra_node = NewValueNode(kIOMonad); } else { MS_EXCEPTION(TypeError) << "The params of function 'bprop' of Primitive or Cell requires the forward inputs as well " "as the 'out' and 'dout'.\n" << trace::GetDebugInfo(bprop_fg->debug_info()); } extra_args.push_back(extra_node); MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString(); } } // bprop_fg has been checked in caller if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) { // Set bprop output as (env, dx, dy, dz, ...) auto cbprop = bprop_fg->output()->cast(); auto &inputs = cbprop->inputs(); std::vector args; args.push_back(NewValueNode(prim::kPrimMakeTuple)); args.push_back(NewValueNode(newenv)); (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); if (!extra_args.empty()) { args.insert(args.end(), extra_args.cbegin(), extra_args.cend()); } return NewCNode(args, bprop_fg); } // Set bprop output as (env, dx) std::string model_name("mindspore.ops.composite.multitype_ops.add_impl"); std::string python_ops("_tuple_add"); auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg); auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name)); if (!extra_args.empty()) { extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple)); auto extra_tuple = NewCNode(extra_args, bprop_fg); auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg); return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg); } return NewCNode({tuple_add_ops, tuple_env, bprop_fg->output()}, bprop_fg); } static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, std::vector *const transf_args) { // bprop_fg has been checked in caller // transform except the last 2 parameters: out, dout. const size_t last_parameter_sizes = 2; auto bprop_fg_param_size = bprop_fg->parameters().size() - last_parameter_sizes; for (size_t i = 0; i < bprop_fg_param_size; ++i) { auto p = bprop_fg->parameters()[i]; MS_EXCEPTION_IF_NULL(p); TraceGuard trace_guard(std::make_shared(p->debug_info())); auto transf_p = outer->add_parameter(); (void)mng->Replace(p, transf_p); transf_args->push_back(transf_p); } } void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const PrimitivePtr &primitive, const FuncGraphPtr &outer, std::vector *const transf_args) { MS_EXCEPTION_IF_NULL(mng); TransformNormalArgs(mng, bprop_fg, outer, transf_args); // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter. auto effect_info = GetPrimEffectInfo(primitive); if (effect_info.memory) { MS_LOG(DEBUG) << "Append U monad to Fprop FuncGraph for Primitive " << primitive->ToString(); auto transf_p = outer->add_parameter(); transf_args->push_back(transf_p); } if (effect_info.io) { MS_LOG(DEBUG) << "Append IO monad to Fprop FuncGraph for Primitive " << primitive->ToString(); auto transf_p = outer->add_parameter(); transf_args->push_back(transf_p); } } template void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const T ¤t_primal_fg, const FuncGraphPtr &outer, std::vector *const transf_args) { MS_EXCEPTION_IF_NULL(mng); TransformNormalArgs(mng, bprop_fg, outer, transf_args); constexpr size_t need_filter_size = 2; auto bprop_fg_param_size = bprop_fg->parameters().size() - need_filter_size; // current_primal_fg may have extra parameters after AutoMonad const auto ¤t_primal_fg_params = current_primal_fg->parameters(); if (bprop_fg_param_size < current_primal_fg_params.size()) { for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) { auto p = current_primal_fg_params[i]; MS_EXCEPTION_IF_NULL(p); // extra parameters should be Monad. if (!HasAbstractMonad(p)) { continue; } MS_LOG(DEBUG) << "Function " << current_primal_fg->ToString() << ", has extra monad parameter: " << p->DebugString() << ", abstract: " << p->abstract()->ToString(); TraceGuard trace_guard(std::make_shared(p->debug_info())); auto transf_p = outer->add_parameter(); // See also Notes on extra_node of BuildOutput. // Notes: No need to replace p with transf_p as the only use of p is here. // If extra_node in bprop_fg use p as free variable, a replacement of p is required here. // This replacement will make the usage of p in current_primal_fg got replaced with transf_p // of outer. outer will be released after it is being cloned to fprop_fg, so the func_graph_ // in transf_p will be nullptr. // So the RULE is DONT tamper the current_primal_fg; transf_args->push_back(transf_p); } } if (transf_args->size() != current_primal_fg_params.size()) { MS_EXCEPTION(TypeError) << "Function " << current_primal_fg->ToString() << ", The number of parameter of this primal function is " << current_primal_fg_params.size() << ", but the number of parameters of bprop is " << bprop_fg_param_size; } } void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); bool check_bprop_flag = context->get_param(MS_CTX_CHECK_BPROP_FLAG); // Skip checking if check_bprop not set if (!check_bprop_flag) { return; } // bprop_fg has been checked in caller auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops"); MS_EXCEPTION_IF_NULL(check_bprop_class); auto check_bprop = bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared(prim_to_check))}); std::vector inputs; inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); constexpr int primitive_size = 1; constexpr int brprop_offset_size = 2; (void)inputs.insert(inputs.begin() + primitive_size, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - brprop_offset_size); AnfNodePtr params = bprop_fg->NewCNode(inputs); inputs.clear(); inputs.push_back(check_bprop); inputs.push_back(bprop_fg->output()); inputs.push_back(params); AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs); bprop_fg->set_output(bprop_out); } FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg) { MS_EXCEPTION_IF_NULL(bprop_fg); // primal_fg is FuncGraph just after convert. Refer ConvertCellObjToFuncGraph. // current_primal_fg is specalized and AutoMoaded primal_fg; auto primal_fg = bprop_fg->transforms().find("primal")->second.func_graph(); auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr, {}, {}); if (expanded_fg == nullptr) { MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString() << " Cell bprop function to K expanded func graph. NodeInfo: " << trace::GetDebugInfo(primal_fg->debug_info()); } return expanded_fg; } FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { auto prim = GetValueNode(value_node); MS_EXCEPTION_IF_NULL(prim); auto &node_users = resources->manager()->node_users(); auto &users = node_users[value_node]; auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { return IsPrimitiveCNode(user.first, prim); }); if (cnode == users.end()) { MS_LOG(EXCEPTION) << "Fail to find cnode."; } auto inputs_num = cnode->first->cast()->size() - 1; auto func_graph = std::make_shared(); std::vector outputs; auto bprop_cut = std::make_shared("bprop_cut"); bprop_cut->CopyHookFunction(prim); auto cell_id = GetValue(prim->GetAttr("cell_id")); if (cell_id != "") { (void)bprop_cut->AddAttr("cell_hook", MakeValue(true)); (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id)); } outputs.push_back(NewValueNode(bprop_cut)); for (size_t i = 0; i < inputs_num; ++i) { auto param = func_graph->add_parameter(); outputs.push_back(param); } auto p1 = func_graph->add_parameter(); auto p2 = func_graph->add_parameter(); outputs.push_back(p1); outputs.push_back(p2); func_graph->set_output(func_graph->NewCNode(outputs)); return func_graph; } FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { auto prim = value_node->value()->cast(); MS_EXCEPTION_IF_NULL(prim); auto &node_users = resources->manager()->node_users(); auto &users = node_users[value_node]; auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { return IsPrimitiveCNode(user.first, prim); }); if (cnode == users.end()) { MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString(); } auto inputs_num = cnode->first->cast()->inputs().size() - 1; auto effect_info = GetPrimEffectInfo(prim); // Don't add U or IO monad parameters as it will be added later. size_t monad_params_size = 0; if (effect_info.memory) { monad_params_size++; } if (effect_info.io) { monad_params_size++; } if (inputs_num < monad_params_size) { MS_LOG(EXCEPTION) << "Arguments number should be greater than or equal to " << monad_params_size << ", but the CNode is: " << cnode->first->DebugString(); } inputs_num -= monad_params_size; auto func_graph = std::make_shared(); std::vector outputs; outputs.push_back(NewValueNode(prim::kPrimMakeTuple)); auto fake_bprop = std::make_shared("fake_bprop"); (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined.")); for (size_t i = 0; i < inputs_num; ++i) { // Mock params for inputs auto param = func_graph->add_parameter(); // Mock derivatives for each inputs outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param})); } // mock params for out and dout (void)func_graph->add_parameter(); (void)func_graph->add_parameter(); func_graph->set_output(func_graph->NewCNode(outputs)); return func_graph; } } // namespace ad } // namespace mindspore