You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

session_basic.h 20 kB

5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
5 years ago
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
  17. #define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
  18. #include <vector>
  19. #include <string>
  20. #include <utility>
  21. #include <memory>
  22. #include <map>
  23. #include <set>
  24. #include "utils/hash_map.h"
  25. #include "backend/common/session/session_context.h"
  26. #include "backend/common/session/kernel_graph.h"
  27. #include "backend/common/session/anf_runtime_algorithm.h"
  28. #include "include/common/utils/anfalgo.h"
  29. #include "ir/anf.h"
  30. #include "ir/tensor.h"
  31. #include "utils/any.h"
  32. #include "include/common/utils/contract.h"
  33. #include "runtime/device/kernel_info.h"
  34. #include "utils/ms_context.h"
  35. #include "runtime/device/bucket.h"
  36. #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
  37. #include "debug/debugger/debugger.h"
  38. #endif
  39. #include "runtime/hardware/device_context.h"
  40. #include "backend/common/session/pynative_task_manager.h"
  41. namespace mindspore {
  42. namespace runtime {
  43. class GraphCompiler;
  44. } // namespace runtime
  45. } // namespace mindspore
  46. namespace mindspore {
  47. using GraphId = uint32_t;
  48. using GraphInfo = std::string;
  49. const char kSessionBasic[] = "SessionBasic";
  50. namespace session {
// Summary callback signature: invoked with a graph id and the named parameter
// tensors of that graph (registered via RegisterSummaryCallBackFunc).
using CallBackFunc = uint32_t (*)(uint32_t graph_id,
                                  const std::map<std::string, mindspore::tensor::TensorPtr> &params_list);
// Heterogeneous value list built on the project's Any type.
using AnyList = std::vector<Any>;
using AnyListPtr = std::shared_ptr<AnyList>;
  55. struct OpRunInfo {
  56. bool is_gradient_out = false;
  57. std::string op_name;
  58. Primitive *primitive;
  59. AbstractBasePtr abstract;
  60. bool is_dynamic_shape = false;
  61. bool is_auto_mixed_precision = false;
  62. bool lazy_build = false;
  63. std::string next_op_name;
  64. #if defined(__APPLE__)
  65. int next_input_index = 0;
  66. #else
  67. size_t next_input_index = 0;
  68. #endif
  69. std::string graph_info;
  70. std::vector<int64_t> tensor_mask;
  71. std::vector<tensor::TensorPtr> input_tensors;
  72. std::string device_target = "Unknown";
  73. };
// Aggregated input data gathered for a single op launch.
struct InputTensorInfo {
  std::vector<tensor::TensorPtr> input_tensors;
  // One mask entry per input tensor; see the tensors_mask parameters used
  // throughout SessionBasic (e.g. ConstructSingleOpGraph, EraseValueNodeTensor).
  std::vector<int64_t> input_tensors_mask;
  // Kernels (with output index) whose outputs feed this op.
  std::set<KernelWithIndex> input_kernel;
};
  79. struct OutputTensorInfo {
  80. tensor::TensorPtr output_stub_tensor;
  81. bool is_weight;
  82. };
  83. struct GraphOutputInfo {
  84. VectorRef *graph_outputs;
  85. std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
  86. std::vector<tensor::TensorPtr> graph_output_tensors;
  87. };
  88. class Executor;
// Base class for backend sessions. A session compiles front-end ANF graphs
// into KernelGraphs and runs them on a device, either as whole graphs
// (CompileGraph/BuildGraph/RunGraph) or op-by-op (RunOp/RunOpsInGraph).
// Device-specific subclasses override the *Impl virtual hooks.
class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
 public:
  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
#if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
    debugger_ = nullptr;
#endif
  }
  // Bind this session to a device id; subclasses extend with device setup.
  virtual void Init(uint32_t device_id) { device_id_ = device_id; }
  void InitExecutor(const std::string &device_name, uint32_t device_id);
  virtual void SyncStream() const {}
  virtual ~SessionBasic() { summary_callback_ = nullptr; }
  // --- graph-mode entry points: compile, build kernels, then run ---
  GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
  GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
  void BuildGraph(GraphId graphId);
  void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
  void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
  // --- op-by-op (PyNative) entry points ---
  void RunOp(OpRunInfo *, VectorRef *outputs);
  void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
#ifndef ENABLE_SECURITY
  virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
#endif
  bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
  // Build a backend KernelGraph from a front-end node list / whole func graph.
  std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
                                                    DeviceAddressType device_target = DeviceAddressType::kUnknown,
                                                    bool common_opt = true);
  std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                    std::vector<KernelGraphPtr> *all_out_graph,
                                                    DeviceAddressType device_target);
  void SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager);
  // Clone a front-end CNode into `graph`; the first overload records
  // cross-graph node correspondences in *other_graph_cnode.
  CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
                          mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
  CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
  // get graph id in child graphs by ME front anf node pointer
  virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const;
  virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
  // --- parameter-server support ---
  void AssignParamKey(const KernelGraphPtr &kernel_graph);
  void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
  // Returns true (and fills *channel_name) if the graph is a GetNext data graph.
  bool IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name);
  // Default implementation accepts any inputs; subclasses may validate.
  virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                std::string *error_msg) const {
    return true;
  }
  void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
                          std::vector<std::string> *inputs_name) const;
  void GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
                           std::vector<std::string> *outputs_name) const;
  std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id,
                                                         const std::vector<tensor::TensorPtr> &inputs);
  // Get graph by graph id, if not exist return null ptr
  KernelGraphPtr GetGraph(GraphId graph_id) const;
  void ClearGraph();
  // create a single run op graph
  std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                      const std::vector<tensor::TensorPtr> &input_tensors,
                                                      const std::vector<int64_t> &tensors_mask, bool is_ascend = false);
  void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
                            std::vector<tensor::TensorPtr> *input_tensors) const;
  void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
  static void RunOpHideNopNode(const KernelGraphPtr &kernel_graph);
  virtual void ReportWarningMessage() {}
  virtual void ReportErrorMessage() {}
  virtual void SetThreadContext() {}
#ifdef ENABLE_DEBUGGER
  // set debugger
  void SetDebugger() {
    debugger_ = Debugger::GetInstance();
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    MS_EXCEPTION_IF_NULL(debugger_);
    debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
  }
#endif

 private:
  // Helpers used by ConstructKernelGraph to translate front-end control-flow
  // nodes (switch / partial / call / switch_layer) into backend inputs.
  CNodePtr CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph);
  std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
  std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph);
  void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
  std::vector<AnfNodePtr> CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph);
  void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const;
  void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
                         mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
  std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
  void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
  void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
                            const FuncGraphManagerPtr &front_func_graph_manager,
                            const std::shared_ptr<KernelGraph> &backend_graph);
  std::string AddPartialParametersMap(const AnfNodePtr &partial_node);
  void GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
                         std::map<AnfNodePtr, size_t> *parameter_index);
  void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
                               VectorRef *const outputs,
                               std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes);
  // Reference counting of op outputs so intermediate tensors can be released
  // as soon as their consumers have run (op-by-op execution).
  void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count);
  void GetForwardOpOutputRefCount(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
                                  std::map<std::string, size_t> *forward_op_output_tensor_id);
  void ReleaseForwardOpOutput(const std::vector<tensor::TensorPtr> &input_tensors,
                              std::map<std::string, size_t> *forward_op_output_tensor_id);
  void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
                      std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map);
  void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                       const std::map<KernelWithIndex, size_t> &ref_count,
                       std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map,
                       GraphOutputInfo *const graph_output_info);

 protected:
  // Task classes and the graph compiler drive sessions through the
  // protected *Impl hooks below.
  friend class Executor;
  friend class CompileNodesTask;
  friend class CompileGraphTask;
  friend class BuildGraphTask;
  friend class RunGraphTask;
  friend class RunOpTask;
  friend class RunOpsInGraphTask;
  friend class mindspore::runtime::GraphCompiler;
  virtual bool IsSupportSummary() { return true; }
  virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
                                   VectorRef *outputs,
                                   std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
                                   KernelMapTensor *node_to_tensor);
  // When the device address of the node is used as the output of the graph, the device address will be passed
  // to the output tensor, and the output node will recreate a new device address. This third parameter records
  // the relationship between the new and old device address.
  virtual void UpdateOutputTensors(const VectorRef *outputs,
                                   const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                                   std::map<DeviceAddressPtr, DeviceAddressPtr> *);
  virtual void UnifyMindIR(const KernelGraphPtr &graph);
  virtual void FinalOptimize(const KernelGraphPtr &graph) const;
  // Default impls are no-ops / sentinel returns; device sessions override.
  virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; }
  virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr>) { return kInvalidGraphId; }
  virtual void BuildGraphImpl(GraphId) {}
  virtual void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
                               const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {
    MS_EXCEPTION_IF_NULL(kernel_graph);
    MS_EXCEPTION_IF_NULL(outputs);
    MS_LOG(INFO) << "Call default PreExecuteGraph with input size: " << inputs.size();
  }
  virtual void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
                                const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {
    MS_EXCEPTION_IF_NULL(kernel_graph);
    MS_EXCEPTION_IF_NULL(outputs);
    MS_LOG(INFO) << "Call default PostExecuteGraph with input size: " << inputs.size();
  }
  virtual void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); }
  void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
  virtual KernelGraphPtr BuildOpImpl(const OpRunInfo & /* op_run_info */, const GraphInfo & /* graph_info */,
                                     const std::vector<tensor::TensorPtr> & /* input_tensors */,
                                     const std::vector<int64_t> & /* tensors_mask */) {
    return nullptr;
  }
  virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                         std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                         const std::vector<int64_t> &tensors_mask) {}
  virtual void RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                               std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                               const std::vector<int64_t> &tensors_mask) {}
  void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
  void ProcessInputTensorsForHeterogeneous(const std::string &cur_target,
                                           const std::vector<tensor::TensorPtr> &input_tensors);
  virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
                               const std::vector<tensor::TensorPtr> &graph_inputs,
                               const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
#ifndef ENABLE_SECURITY
  virtual void SetSummaryNodes(KernelGraph *graph);
#endif
  // Copy host input tensors to the device for the given graph; skips graphs
  // that are not executable.
  void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) {
    MS_LOG(INFO) << "Status record: start load input. graph id: " << graph_id;
    auto kernel_graph = GetGraph(graph_id);
    MS_EXCEPTION_IF_NULL(kernel_graph);
    if (!kernel_graph->executable()) {
      return;
    }
    LoadInputData(kernel_graph, inputs_const);
    MS_LOG(INFO) << "Status record: end load input. graph id: " << graph_id;
  }
  virtual void ExecuteAllTaskInQueue() {}
  // Default LoadInputData only validates and logs; device sessions override.
  virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                             const std::vector<tensor::TensorPtr> &inputs_const) const {
    MS_EXCEPTION_IF_NULL(kernel_graph);
    MS_LOG(INFO) << "Call default LoadInputData with input size: " << inputs_const.size();
  }
  void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
                     const std::vector<tensor::TensorPtr> &input_tensors,
                     std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const;
  void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const;
#ifndef ENABLE_SECURITY
  void Summary(KernelGraph *graph);
#endif
  // create graph output for RunOp
  void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
  CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
  // Generate graph info for a single op graph
  GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors);
  OpRunInfo GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInfo &graph_info, const InputTensorInfo &tensor_info,
                               GraphOutputInfo *const graph_output_info);
  // Resolve a node's output tensor from a value node / graph parameter /
  // previously executed op output, respectively.
  tensor::TensorPtr GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index);
  tensor::TensorPtr GetParameterOutputTensor(const AnfNodePtr &node,
                                             const std::map<AnfNodePtr, size_t> &parameter_index,
                                             const std::vector<tensor::TensorPtr> &graph_inputs);
  tensor::TensorPtr GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
                                         const std::map<KernelWithIndex, tensor::TensorPtr> &op_output);
  void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
                         const std::map<AnfNodePtr, size_t> &parameter_index,
                         const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info);
  tensor::TensorPtr GetOpInputTensorByIndex(const CNodePtr &cnode,
                                            const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
                                            const std::map<AnfNodePtr, size_t> &parameter_index,
                                            const std::vector<tensor::TensorPtr> &graph_inputs,
                                            InputTensorInfo *const input_tensor_info, size_t input_index);
  // create a new kernel graph and update the graph sum
  KernelGraphPtr NewKernelGraph();
  AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);
  virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph);
  ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
  ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
  AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph);
  void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
  void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
  AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
  // --- gradient all-reduce bucket management ---
  virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; }
  void InitAllBucket(const KernelGraphPtr &graph, const device::DeviceContext *device_context = nullptr);
  void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
  void ClearAllBucket(const GraphId &graph_id);
  std::vector<uint32_t> GetAllReduceSplitIndex();
  virtual std::string GetCommWorldGroup() { return std::string(); }
  void DumpGraphs(const std::vector<KernelGraphPtr> &graphs);
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
  // Parameter-server worker support (CPU, non-Windows/macOS builds only).
  void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
  void GetBatchElements(const AnfNodePtr &kernel_node) const;
  void InitPsWorker(const KernelGraphPtr &kernel_graph);
#endif
  // Buckets per graph id, and the next free bucket id per graph id.
  std::map<uint32_t, std::vector<std::shared_ptr<device::Bucket>>> bucket_map_;
  std::map<uint32_t, uint32_t> free_bucket_id_map_;
  // Compiled graphs keyed by graph id / single-op graph info string.
  mindspore::HashMap<GraphId, std::shared_ptr<KernelGraph>> graphs_;
  mindspore::HashMap<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
  // Front-end FuncGraph -> backend KernelGraph correspondence.
  mindspore::HashMap<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
  mindspore::HashMap<AnfNodePtr, AnfNodePtr> partial_parameters_map_;
  mindspore::HashMap<AnfNodePtr, std::string> partial_target_map_;
  mindspore::HashMap<AnfNodePtr, ParameterPtr> default_param_map_;
  std::shared_ptr<Context> context_;
  CallBackFunc summary_callback_;
  // Shared across all sessions: running count used when assigning graph ids
  // (see NewKernelGraph comment above).
  static GraphId graph_sum_;
  uint32_t device_id_;
  // rank id of physical device
  uint32_t rank_id_{0};
  std::shared_ptr<Executor> executor_;
#if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
  std::shared_ptr<Debugger> debugger_;
#endif
};
  336. using SessionPtr = std::shared_ptr<session::SessionBasic>;
  337. using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>;
  338. } // namespace session
  339. void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir,
  340. const std::vector<CNodePtr> &execution_order);
  341. uint32_t GetRankId();
  342. } // namespace mindspore
  343. #endif // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H