/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H #define MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H #include #include #include #include #include #include #include "session/session_basic.h" #include "session/kernel_graph.h" #include "kernel/kernel.h" #include "session/session_factory.h" namespace mindspore { namespace session { enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 }; class AscendSession : public SessionBasic { public: AscendSession() { final_graph_id_ = kInvalidGraphId; } ~AscendSession() override = default; void Init(uint32_t device_id) override { SessionBasic::Init(device_id); context_ = std::make_shared(kAscendDevice, device_id); } GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; void BuildGraph(GraphId) override; void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, const std::vector &tensors_mask) override; py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors) override; // set parameters of final graph GraphId SetFinalGraphInput(const std::vector &args) override; // set output of final graph void SetFinalGraphOutput(const BaseRef &output) override; // insert switch and set the relative active ops void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter void SetChildGraphInput(GraphId g, const VectorRef &args) override; // get graph id in child graphs by ME front anf node pointer GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; // get graph id of final graph GraphId GetFinalRunGraph() const override { return final_graph_id_; } // insert active to graph void SetActive(GraphId, GraphId) override; private: void InitRuntimeResource(); void SelectKernel(const KernelGraph &kernel_graph) const; void HardwareOptimize(const std::shared_ptr &kernel_graph) const; void AdjustKernel(const std::shared_ptr &kernel_graph) const; void RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const; void AssignStream(const std::shared_ptr &kernel_graph) const; void BuildKernel(const std::shared_ptr &kernel_graph) const; void MemoryAlloc(KernelGraph *kernel_graph) const; void RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const; void GenerateTaskInfo(const std::shared_ptr &kernel_graph) const; void LoadTask(const std::shared_ptr &kernel_graph) const; void ExecTask(const std::shared_ptr &kernel_graph) const; void Dump(const std::shared_ptr &kernel_graph) const; // below functions are used for run op void RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const; void RunOpExecTask(const std::shared_ptr &kernel_graph) const; size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index); size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index); size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index); // merge execution order list of child graphs void MergeGraphExecOrder(); // insert assion op to sync data bettween different graphs void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); // insert mutiple assigns to graph void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); // insert active op to graph void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream); // get execute index of graph size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph); // handle condition graph from vm void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id); // Get graph by graph id ,if not exist return null ptr KernelGraphPtr GetGraph(GraphId graph_id); // set child graph parameter if front arg is a anf void SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter); // set child graph parameter if front arg is a tensor void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter); // update the execution order of all child graphs void UpdateGraphOrder(GraphId to_graph); // handle switch when merge void MergeSwitchCompile(); // get graph order vector by graph id std::vector &GetGraphOrder(GraphId final_graph_id); // get graph order type vector by graph id std::vector &GetGraphOrderType(GraphId final_graph_id); // copy output of if and else void CopyOutputOfIf(GraphId false_graph_id); // check if graph cache exist bool GraphCacheExist(const GraphInfo &graph_info) const; // member variables // key is final_graph_id,value is child graph execute order of final graph std::unordered_map> graph_execute_orders_; // key is final_graph_id,value is the graph types of child graphs std::unordered_map> graph_order_types_; // record condition graph of while std::unordered_map while_condition_graphs_; // record all conditions std::unordered_map> switches_; std::unordered_map condition_output_; // final_graph_id is used in every root graph has it's own session situation GraphId final_graph_id_; }; MS_REG_SESSION(kAscendDevice, AscendSession); } // namespace session } // namespace mindspore #endif // MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H