/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ #define MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ #include #include #include #include #include #include "ir/anf.h" #include "ir/meta_func_graph.h" #include "ir/func_graph_cloner.h" #include "pipeline/resource.h" #include "optimizer/ad/adjoint.h" #include "operator/ops.h" #include "debug/trace.h" namespace mindspore { namespace ad { struct PrimitiveTotalEqual { bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { if (t1->name() != t2->name()) { return false; } auto const &attrs1 = t1->attrs(); auto const &attrs2 = t2->attrs(); if (attrs1.size() != attrs2.size()) { return false; } for (auto &attr1 : attrs1) { if (!t2->HasAttr(attr1.first)) { return false; } if (!(*(attr1.second) == *(t2->GetAttr(attr1.first)))) { return false; } } return true; } }; using Registry = std::unordered_map; class KPrim; extern KPrim g_k_prims; class DFunctor; using DFunctorPtr = std::shared_ptr; // D Functor's rules to map closure object and morphisms. class DFunctor : public std::enable_shared_from_this { public: DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources); ~DFunctor() = default; // Map object in D category to K category. void MapObject(); // Map morphism in D category to K category. void MapMorphism(); FuncGraphPtr k_graph(); // Construct user defined k object. FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); // Register functor objects to form a global view. void Init(bool is_top = false); bool IsInScope(const AnfNodePtr &node); // Clear resources. static void Clear(); private: // Map one morphism. AdjointPtr MapMorphism(const AnfNodePtr &morph); bool IsFreeMorphism(const AnfNodePtr &node); // Map morphism that's not attached to output. void MapFreeMorphism(); void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din); void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env); void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); // Map Anfnode object from D category to K category. AnfNodePtr MapToK(const AnfNodePtr &primal); // Map FuncGraph object from D category to K category. AnfNodePtr MapToK(const FuncGraphPtr &primal); // MapObject impls. void MapFvObject(); void MapValueObject(); void MapParamObject(); // Find adjoint with its primary k. AdjointPtr FindAdjoint(const AnfNodePtr &primal); // Broadcast stop flags. void BroadCastStopFlag(); bool AllReferencesStopped(const CNodePtr &node); // Update k hole with adjoint_definition, only applied in recursive case. void UpdateAdjoint(const AdjointPtr &adjoint_definition); void CallDoutHoleOnTape(); std::unordered_map anfnode_to_adjoin_; // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. std::unordered_map anfnode_to_adjoin_indirect_fv_; FuncGraphPtr primal_graph_; // K object for primal_graph_; FuncGraphPtr k_graph_; // The Backprop part of k_graph_. FuncGraphPtr tape_; // Dout parameter for primal_graph_. AnfNodePtr dout_; pipeline::ResourceBasePtr resources_; // Cut off stopped objects in category D. bool need_cut_; bool is_top_; static std::unordered_map> func_graph_to_functor_; static std::unordered_map anfnode_to_adjoin_definition_; static FuncGraphSet scope_; }; // D Functor's rules to map primitive object. class KPrim { public: KPrim() = default; ~KPrim() = default; FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); void clear() { bprop_registry_meta_.clear(); bprop_registry_.clear(); } private: FuncGraphPtr GetBprop(const PrimitivePtr &prim); FuncGraphPtr GetFprop(const PrimitivePtr &prim); FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); // Given a bprop rule, do the K mapping. template FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g); AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, std::vector *const transf_args); void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check); Registry bprop_registry_; std::unordered_map bprop_registry_meta_; }; template FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { MS_EXCEPTION_IF_NULL(primal); MS_EXCEPTION_IF_NULL(bprop_fg); CheckBprop(bprop_fg, primal->ToString()); auto debug_info = std::make_shared(); debug_info->set_name(primal->ToString()); auto cloned_bprop_fg = BasicClone(bprop_fg); MS_EXCEPTION_IF_NULL(cloned_bprop_fg); cloned_bprop_fg->debug_info()->set_name(""); cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared(debug_info)); AnfNodePtr bout = BuildOutput(cloned_bprop_fg); cloned_bprop_fg->set_output(bout); TraceManager::DebugTrace(std::make_shared(debug_info)); auto outer = std::make_shared(); (void)outer->transforms().emplace("primal", FuncGraphTransform(primal)); outer->set_output(NewValueNode(kNone)); TraceManager::EndTrace(); auto mng = Manage({cloned_bprop_fg, outer}, false); // Make sure (out, dout) provided. if (cloned_bprop_fg->parameters().size() < 2) { MS_LOG(EXCEPTION) << "Primitive or Cell " << primal->ToString() << " bprop requires out and dout at least, but only got " << cloned_bprop_fg->parameters().size() << " params. NodeInfo: " << trace::GetDebugInfo(cloned_bprop_fg->debug_info()); } // In a bprop definition, the last two param should be out and dout. auto dout = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 1]; auto out_param = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 2]; std::vector transf_args; TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); TraceManager::DebugTrace(std::make_shared(dout->debug_info())); (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); auto out_value = outer->NewCNode(transf_args); TraceManager::EndTrace(); (void)mng->Replace(out_param, out_value); TraceManager::DebugTrace(std::make_shared(out_param->debug_info())); auto new_dout = cloned_bprop_fg->add_parameter(); (void)mng->Replace(dout, new_dout); // We remove all parameters except new_dout. std::vector newBpropParams = {new_dout}; cloned_bprop_fg->set_parameters(newBpropParams); TraceManager::EndTrace(); outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); return BasicClone(outer); } } // namespace ad } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_