
session_basic.h 7.6 kB

adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
```cpp
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H

#include <vector>
#include <string>
#include <unordered_map>
#include <utility>
#include <memory>
#include <map>
#include "utils/base_ref_extends.h"
#include "backend/session/session_context.h"
#include "backend/session/kernel_graph.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/contract.h"
#include "pipeline/pynative/pynative_execute.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif

namespace mindspore {
using GraphId = uint32_t;
using GraphInfo = std::string;
namespace session {
void ClearPythonParasMap();
using CallBackFunc = uint32_t (*)(uint32_t graph_id,
                                  const std::map<std::string, mindspore::tensor::TensorPtr> &params_list);
using AnyList = std::vector<Any>;
using AnyListPtr = std::shared_ptr<AnyList>;

using OpRunInfo = pynative::OpExecInfo;
using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;

class SessionBasic {
 public:
  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
#ifdef ENABLE_DEBUGGER
    debugger_ = nullptr;
#endif
  }

  virtual void Init(uint32_t device_id) { device_id_ = device_id; }

  virtual ~SessionBasic() { summary_callback_ = nullptr; }

  virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
  virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
  // Build the graph; used to handle multiple child graphs.
  virtual void BuildGraph(GraphId) {}
  virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0;
  virtual void BuildOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
                       const std::vector<int> &tensors_mask) {}
  virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors) {
    return py::tuple();
  }

  virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);

  void CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph);

  std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
  std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                    std::vector<KernelGraphPtr> *all_out_graph);

  CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
                          std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
  CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph);

  // Get the graph id in child graphs by the ME front-end ANF node pointer.
  virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
  virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
  void AssignParamKey(const KernelGraphPtr &kernel_graph);
  void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
  virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                std::string *error_msg) const {
    return true;
  }
  virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {}
#ifdef ENABLE_DEBUGGER
  // Set the debugger.
  void SetDebugger() {
    debugger_ = Debugger::GetInstance();
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    debugger_->Init(device_id_, ms_context->device_target());
  }
#endif

 private:
  CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
  std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
  std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph);
  void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);

 protected:
  virtual void SetSummaryNodes(KernelGraph *graph);
  // Get a graph by graph id; return nullptr if it does not exist.
  KernelGraphPtr GetGraph(GraphId graph_id) const;
  virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                             const std::vector<tensor::TensorPtr> &inputs_const) const;
  void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
                     const std::vector<tensor::TensorPtr> &input_tensors) const;
  void Reorder(std::vector<CNodePtr> *node_list);
  void Summary(KernelGraph *graph);
  // Create the graph output for RunOp.
  void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
  CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
  // Create a graph that runs a single op.
  std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                      const std::vector<tensor::TensorPtr> &input_tensors,
                                                      const std::vector<int> &tensors_mask);
  // Transform a BaseRef list to a py::tuple.
  BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
  // Create a new kernel graph and update the graph sum.
  KernelGraphPtr NewKernelGraph();
  std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph);
  virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
  ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
  ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
  AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
  void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
  void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
  AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);

  std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
  std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
  std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_;
  std::shared_ptr<Context> context_;
  CallBackFunc summary_callback_;
  static GraphId graph_sum_;
  uint32_t device_id_;
#ifdef ENABLE_DEBUGGER
  std::shared_ptr<Debugger> debugger_;
#endif
};

using SessionPtr = std::shared_ptr<session::SessionBasic>;
using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>;
}  // namespace session
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
```
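`session_basic.h` only declares the bookkeeping members. Their definitions are not part of this file. The standalone sketch below is a hypothetical illustration of the pattern the declarations and comments suggest, not MindSpore's implementation:

- New graphs take an id from the static counter `graph_sum_` ("create a new kernel graph and update the graph sum").
- Graphs are stored in an id-keyed map, and lookup yields a null pointer for an unknown id ("if not exist return null ptr").
- The summary callback has the same signature shape as `CallBackFunc`.

`MiniSession`, `ReportSummary`, `CountingCallback`, `Demo`, and the `Tensor`/`KernelGraph` stubs are invented for illustration. The assumption that ids increase by one per graph is also ours and is not confirmed by the header.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for MindSpore types (not the real definitions).
using GraphId = uint32_t;
struct Tensor {
  float value;
};
using TensorPtr = std::shared_ptr<Tensor>;
struct KernelGraph {
  explicit KernelGraph(GraphId id) : graph_id(id) {}
  GraphId graph_id;
};
using KernelGraphPtr = std::shared_ptr<KernelGraph>;

// Same shape as session::CallBackFunc, with the tensor type stubbed out.
using CallBackFunc = uint32_t (*)(uint32_t graph_id, const std::map<std::string, TensorPtr> &params_list);

// Hypothetical, stripped-down session that mirrors only a few SessionBasic members.
class MiniSession {
 public:
  virtual ~MiniSession() = default;

  void RegisterSummaryCallBackFunc(const CallBackFunc &callback) { summary_callback_ = callback; }

  // Assumption: take the next id from a counter shared by all sessions, then record the graph.
  KernelGraphPtr NewKernelGraph() {
    auto graph = std::make_shared<KernelGraph>(graph_sum_++);
    graphs_[graph->graph_id] = graph;
    return graph;
  }

  // Look up a graph by id; return nullptr if this session has no such graph.
  KernelGraphPtr GetGraph(GraphId graph_id) const {
    auto it = graphs_.find(graph_id);
    if (it == graphs_.end()) {
      return nullptr;
    }
    return it->second;
  }

  // Hypothetical helper: forward named tensors to the registered callback, or return 0 if none is set.
  uint32_t ReportSummary(GraphId graph_id, const std::map<std::string, TensorPtr> &params) const {
    if (summary_callback_ == nullptr) {
      return 0;
    }
    return summary_callback_(graph_id, params);
  }

 private:
  std::unordered_map<GraphId, KernelGraphPtr> graphs_;
  CallBackFunc summary_callback_ = nullptr;
  static GraphId graph_sum_;
};

GraphId MiniSession::graph_sum_ = 0;

// Hypothetical callback: reports how many named tensors it received.
uint32_t CountingCallback(uint32_t graph_id, const std::map<std::string, TensorPtr> &params_list) {
  (void)graph_id;
  return static_cast<uint32_t>(params_list.size());
}

// Usage example: ids are consecutive, and a registered callback sees the named tensors.
void Demo() {
  MiniSession session;
  KernelGraphPtr g0 = session.NewKernelGraph();
  KernelGraphPtr g1 = session.NewKernelGraph();
  assert(g1->graph_id == g0->graph_id + 1);
  assert(session.GetGraph(g0->graph_id) == g0);
  session.RegisterSummaryCallBackFunc(&CountingCallback);
  assert(session.ReportSummary(g1->graph_id, {{"loss", std::make_shared<Tensor>(Tensor{0.5f})}}) == 1);
}
```

Making the counter `static` means ids stay unique across every session object in the process, while each session keeps its own map. Under this sketch's assumptions, a session therefore returns `nullptr` for ids that another session allocated.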