You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ascend_stream_assign.h 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_
  17. #define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "runtime/base.h"
#include "runtime/rt_model.h"
#include "runtime/stream.h"
#include "session/kernel_graph.h"
  28. namespace mindspore {
  29. namespace device {
  30. namespace ascend {
  31. using std::map;
  32. using std::shared_ptr;
  33. using std::unordered_map;
  34. using std::unordered_set;
  35. using std::vector;
  36. class AscendStreamAssign {
  37. public:
  38. static AscendStreamAssign &GetInstance() {
  39. static AscendStreamAssign instance; // Guaranteed to be destroyed.
  40. return instance;
  41. }
  42. AscendStreamAssign(const AscendStreamAssign &) = delete;
  43. AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
  44. uint32_t GetTotalStreamNum() const;
  45. // new stream policy
  46. uint32_t total_common_stream_num() const { return total_common_stream_num_; }
  47. uint32_t total_independ_stream_num() const { return total_independ_stream_num_; }
  48. uint32_t total_event_num() const { return total_event_num_; }
  49. void InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  50. void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  51. void ResetNew();
  52. void AssignStreamNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  53. bool IsIndependentNode(const CNodePtr &node_ptr);
  54. const std::unordered_map<uint32_t, uint32_t> &logic_to_independent_map() { return logic_to_independent_map_; }
  55. const std::unordered_map<uint32_t, uint32_t> &logic_to_physic_map() { return logic_to_physic_map_; }
  56. const std::vector<std::vector<uint32_t>> &inner_parallel_streams() { return inner_parallel_streams_; }
  57. void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
  58. const std::vector<uint32_t> &hcom_streams() { return hcom_stream_list_; }
  59. CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
  60. uint32_t stream_id);
  61. CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
  62. uint32_t stream_id);
  63. private:
  64. AscendStreamAssign() = default;
  65. ~AscendStreamAssign() = default;
  66. vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
  67. const CNodePtr &node);
  68. bool IsHcom(const CNodePtr &apply_kernel);
  69. bool IsProcessed(uint32_t logic_id);
  70. void TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids);
  71. void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index,
  72. uint32_t *cur_stream_id);
  73. void RecordIdMap(uint32_t logic_id, uint32_t physic_id);
  74. void UpdateStreamActive(const CNodePtr &active_ptr);
  75. void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr);
  76. bool IsTaskSink();
  77. void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id);
  78. void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  79. void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  80. void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  81. void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
  82. uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr);
  83. void SetCommonStreamNum(uint32_t cur_stream_id);
  84. void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  85. bool IsProcessedParallelStream(uint32_t stream_id);
  86. void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
  87. void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  88. void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  89. void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  90. uint32_t total_common_stream_num_{0};
  91. uint32_t total_independ_stream_num_{0};
  92. uint32_t total_event_num_{0};
  93. uint32_t first_physic_id_{UINT32_MAX};
  94. uint32_t first_logic_id_{UINT32_MAX};
  95. uint32_t independent_id_{UINT32_MAX};
  96. vector<uint32_t> processed_logic_id_{};
  97. std::unordered_map<uint32_t, uint32_t> logic_to_physic_map_{}; // key:logic id, value: first physic id
  98. std::unordered_map<uint32_t, uint32_t> logic_to_independent_map_{}; // key:logic id, value: dependent id
  99. std::vector<uint32_t> independent_before_physic_id_{}; // record independent id before first physic id
  100. std::vector<std::vector<uint32_t>> inner_parallel_streams_{};
  101. std::vector<uint32_t> processed_parallel_streams_{};
  102. std::vector<uint32_t> hcom_stream_list_{};
  103. std::vector<uint32_t> need_first_active_streams_{};
  104. // new policy end
  105. };
  106. } // namespace ascend
  107. } // namespace device
  108. } // namespace mindspore
  109. #endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_