/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ #define MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ #define DRAW_GE_GRAPH #include #include #include #include #include #include #include #include #include #include "ir/anf.h" #include "ir/func_graph.h" #include "transform/util.h" #include "ir/meta_tensor.h" #include "transform/df_graph_manager.h" #include "utils/config_manager.h" #include "transform/op_declare.h" #include "graph/operator_reg.h" #ifdef OPEN_SOURCE #include "ge/client/ge_api.h" #else #include "external/ge/ge_api.h" #endif #include "graph/tensor.h" #include "ops/all_ops.h" namespace mindspore { namespace transform { class OpAdapterDesc { public: OpAdapterDesc() : train_(nullptr), infer_(nullptr) {} OpAdapterDesc(const OpAdapterPtr &train, const OpAdapterPtr &infer) : train_(train), infer_(infer) {} explicit OpAdapterDesc(const OpAdapterPtr &common) : train_(common), infer_(common) {} OpAdapterDesc(const OpAdapterDesc &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; } OpAdapterDesc(OpAdapterDesc &&desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; desc.train_ = nullptr; desc.infer_ = nullptr; } ~OpAdapterDesc() = default; OpAdapterPtr Get(bool train) const { return train ? train_ : infer_; } OpAdapterDesc &operator=(const OpAdapterDesc &desc) { if (this != &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; } return *this; } OpAdapterDesc &operator=(OpAdapterDesc &&desc) { if (this != &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; desc.train_ = nullptr; desc.infer_ = nullptr; } return *this; } private: OpAdapterPtr train_; OpAdapterPtr infer_; }; using OpAdapterDescPtr = std::shared_ptr; using TensorOrderMap = std::map>; class DfGraphConvertor { public: explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph), df_graph_(std::make_shared(anf_graph_->ToString())) { #if (!defined ENABLE_GE) || (defined ENABLE_INFER) auto it_training = anf_graph->flags().find("training"); if (it_training != anf_graph->flags().end()) { training_ = it_training->second; } else { training_ = false; } #else training_ = ENABLE_TRAIN; #endif auto it_distribute = anf_graph->flags().find("broadcast_flag"); if (it_distribute != anf_graph->flags().end()) { ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION); distribute_ = it_distribute->second; } else { ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE); distribute_ = false; } MS_LOG(INFO) << "Create DfGraphConvertor with training: " << training_ << ", distribute: " << distribute_; } ~DfGraphConvertor() {} static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt) { get_adpt_map()[name] = std::make_shared(adpt); } static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { get_adpt_map()[name] = std::make_shared(train_adpt, infer_adpt); } void DrawComputeGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; return; } fout << compute_sout_.str(); fout.close(); } void DrawInitGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; return; } fout << init_sout_.str(); fout.close(); } void DrawSaveCheckpointGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; return; } fout << checkpoint_sout_.str(); fout.close(); } DfGraphConvertor &ConvertAllNode(); DfGraphConvertor &BuildGraph(); DfGraphConvertor &InitParam(const TensorOrderMap &tensors); DfGraphConvertor &GenerateCheckpointGraph(); DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); void InitParamWithData(const TensorOrderMap &tensors); void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); void SetupBroadcast(const std::shared_ptr &broadcast, const std::vector &broadcast_desc, const DfGraphPtr &broadcast_graph, std::vector broadcast_input); void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it); void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input); void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); DfGraphPtr GetComputeGraph(); DfGraphPtr GetInitGraph(); DfGraphPtr GetSaveCheckpointGraph(); DfGraphPtr GetBroadcastGraph(); static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); int ErrCode() const { return static_cast(error_); } static std::unordered_map &get_adpt_map(); bool is_training() const { return training_; } void set_training(bool is_training) { training_ = is_training; } protected: void InitLoopVar(std::vector *init_input); private: std::ostringstream compute_sout_; std::ostringstream init_sout_; std::ostringstream checkpoint_sout_; std::ostringstream restore_checkpoint_sout_; std::unordered_map op_draw_name_; AnfNodePtr TraceTupleGetItem(const CNodePtr &node, unsigned int *index); AnfNodePtr TraceMakeTuple(const CNodePtr &node, unsigned int index); AnfNodePtr TraceDepend(const CNodePtr &node); OutHandler TraceRealOp(AnfNodePtr node); OutHandler GetHandler(const AnfNodePtr &node, const std::stack &index_stack, AnfNode *const draw_index); OperatorPtr Convert(AnfNodePtr node); OperatorPtr ConvertCNode(CNodePtr node); std::vector ConvertDependNode(AnfNodePtr node); AnfNodePtr GetRealOpNode(AnfNodePtr node); std::vector GetDependNodes(const AnfNodePtr &node); OperatorPtr ConvertParameter(AnfNodePtr node); Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); OperatorPtr ConvertValueNode(ValueNodePtr node); void ConvertTupleGetItem(const CNodePtr node); void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, const std::shared_ptr> &src_ops_list, const std::shared_ptr> &dst_ops_list); bool GetControlDependList(const CNodePtr &node, const std::shared_ptr> &src_ops_list, const std::shared_ptr> &dst_ops_list); void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); void ConvertControlDependNode(const CNodePtr node); void ConvertMakeTuple(const CNodePtr node); bool CheckCNode(const std::string &name, const CNodePtr node); void TraceOutput(AnfNodePtr node); void TraceOutputFromParameter(const AnfNodePtr &anf_out); void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); void SetNodeInput(AnfNodePtr node); void SetOpControlInput(const AnfNodePtr node); void UpdateOpDesc(AnfNodePtr node); void BuildSaveCheckpointGraph(); void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; void AddGraphConstInput(const OperatorPtr &op); std::shared_ptr anf_graph_{nullptr}; std::shared_ptr df_graph_{nullptr}; std::shared_ptr init_graph_{nullptr}; std::shared_ptr save_ckp_graph_{nullptr}; std::shared_ptr restore_ckp_graph_{nullptr}; std::shared_ptr broadcast_graph_{nullptr}; std::unordered_map op_cache_; std::unordered_map> control_depend_cache_; /* record "tuple_getitem"<->"out_handler" mapping */ std::unordered_map out_handle_cache_; /* record "make_tuple"<->"out_handler vector" mapping */ std::unordered_map>> tuple_out_handle_cache_; std::unordered_map params_; std::unordered_map vars_; std::vector> graph_outputs_; std::vector graph_const_inputs_; std::vector init_ops_; std::vector broadcast_ops_; OperatorPtr dataset_iter_getnext_; Status error_ = SUCCESS; bool training_ = false; bool distribute_ = false; }; } // namespace transform } // namespace mindspore #endif // MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_