/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H #define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H #include #include #include #include #include #include #include #include #include #include "backend/session/session_basic.h" #include "backend/session/kernel_graph.h" #include "backend/kernel_compiler/kernel.h" #include "backend/session/session_factory.h" #include "backend/session/ascend_control_parser.h" #include "runtime/context.h" namespace mindspore { namespace session { enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 }; class AscendSession : public SessionBasic { public: AscendSession() { final_graph_id_ = kInvalidGraphId; } ~AscendSession() { if (rt_context_ != nullptr) { auto ret = rtCtxDestroy(rt_context_); if (ret != RT_ERROR_NONE) { MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]"; } rt_context_ = nullptr; } } void Init(uint32_t device_id) override { InitDevice(kAscendDevice, device_id); auto ret = rtCtxCreate(&rt_context_, 0, device_id); if (ret != RT_ERROR_NONE) { MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast(ret) << "]"; } ret = rtCtxSetCurrent(rt_context_); if (ret != RT_ERROR_NONE) { MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; } } // get graph id in child graphs by ME front anf node pointer GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; // get graph id of final graph GraphId GetFinalRunGraph() const override { return final_graph_id_; } protected: GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; GraphId CompileGraphImpl(NotNull func_graph) override; GraphId CompileGraphImpl(NotNull func_graph, const std::vector &inputs) override; void RunGraphImpl(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; void BuildGraphImpl(GraphId) override; void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, const std::vector &tensors_mask) override; void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, VectorRef *outputs) override; private: // compile child graph when session have multiple child graphs void CompileChildGraph(const KernelGraphPtr &child_graph); void RecurseSetSummaryNodes(KernelGraph *graph, std::map> *summary); void SetSummaryNodes(KernelGraph *graph) override; void InitRuntimeResource(); void SelectKernel(const KernelGraph &kernel_graph) const; void HardwareOptimize(const std::shared_ptr &kernel_graph) const; void AdjustKernel(const std::shared_ptr &kernel_graph) const; void RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const; void AssignStream(NotNull kernel_graph) const; void BuildKernel(const std::shared_ptr &kernel_graph) const; void MemoryAlloc(KernelGraph *kernel_graph) const; void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector &input_tensors, KernelGraph *kernel_graph) const; void RunOpMemoryClear(const KernelGraph *kernel_graph) const; void Load(const std::shared_ptr &kernel_graph) const; void Execute(const std::shared_ptr &kernel_graph, bool is_task) const; void Dump(const std::shared_ptr &kernel_graph) const; void DumpAllGraphs(const std::vector &all_graphs); void LoadTensor(const std::shared_ptr &kernel_graph) const; // below functions are used for run op void RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const; static void BackendOptimization(const std::vector &all_graphs); static void LinkChildGraphs(NotNull graph); void RootGraphExecutorValidate(NotNull graph); // merge execution order list of child graphs void MergeGraphExecOrder(); // insert assion op to sync data bettween different graphs void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); // get graph order vector by graph id const std::vector &GetGraphOrder(GraphId final_graph_id) const; // get graph order type vector by graph id const std::vector &GetGraphOrderType(GraphId final_graph_id) const; // check if graph cache exist bool GraphCacheExist(const GraphInfo &graph_info) const; // insert all assign to child graph void InsertAllAssigns(); // sync intial tensors' data to device void SyncInitialTenosrToDevice(); void SetFinalGraphSummaryFlag(const std::shared_ptr &kernel_graph); // create parameter to receive data from multiple branch output void CreateMultiBranchOutput(NotNull graph, NotNull *> memo); void SelectKernel(NotNull root_graph); void RecurseSelectKernelInfo(NotNull graph, NotNull *> const memo, size_t *const raise_precision_count, size_t *const reduce_precision_count) const; void IrFusionPass(const NotNull graph, NotNull *> memo); void HardwareOptimize(const NotNull graph, NotNull *> memo) const; void AssignStaticMemory(const NotNull graph, NotNull *> memo) const; void UpdateRefOutputMap(const NotNull graph, NotNull *> memo) const; // key is final_graph_id,value is child graph execute order of final graph std::unordered_map> graph_execute_orders_; // key is final_graph_id,value is the graph types of child graphs std::unordered_map> graph_order_types_; // share parameters std::vector> assigns_; // initial tensors, these tensor will sync data to device before run graph std::map, tensor::TensorPtr> initial_tenosrs_; // final_graph_id is used in every root graph has it's own session situation GraphId final_graph_id_; // ascend runtime context rtContext_t rt_context_{nullptr}; }; MS_REG_SESSION(kAscendDevice, AscendSession); } // namespace session } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H