From 58403932ee5d90a6b220c5858647c865dc331e0d Mon Sep 17 00:00:00 2001
From: gukecai
Date: Thu, 11 Jun 2020 21:10:34 +0800
Subject: [PATCH 1/2] add sync between hcom

---
 .../device/ascend/ascend_stream_assign.cc | 71 +++++++++++++++++++
 .../device/ascend/ascend_stream_assign.h  |  1 +
 2 files changed, 72 insertions(+)

diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
index 26ab826a7f..10d98856ec 100644
--- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
@@ -291,6 +291,74 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr<session::KernelGraph> &graph_ptr) {
+void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<session::KernelGraph> &graph_ptr) {
+  MS_LOG(INFO) << "start";
+  MS_EXCEPTION_IF_NULL(graph_ptr);
+  auto cnode_ptr_list = graph_ptr->execution_order();
+  vector<uint32_t> fusion_hcom_index;
+  vector<CNodePtr> orders;
+  for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
+    auto cur_cnode = cnode_ptr_list[i];
+    if (IsHcom(cur_cnode)) {
+      fusion_hcom_index.emplace_back(i);
+    }
+  }
+
+  if (fusion_hcom_index.size() < 2) {
+    MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them";
+    return;
+  }
+
+  uint32_t first_index = fusion_hcom_index[0];
+  uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1];
+
+  uint32_t cur_event_id = total_event_num_;
+  uint32_t pre_hcom_stream_id = UINT32_MAX;
+  std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders));
+  for (size_t i = first_index; i <= last_index; i++) {
+    auto cur_cnode = cnode_ptr_list[i];
+    auto it = std::find(fusion_hcom_index.begin(), fusion_hcom_index.end(), i);
+    if (it == fusion_hcom_index.end()) {
+      orders.emplace_back(cur_cnode);
+      continue;
+    }
+
+    auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode);
+    if (cur_hcom_stream_id == pre_hcom_stream_id) {
+      orders.emplace_back(cur_cnode);
+      continue;
+    }
+
+    if (i == first_index) {
+      // first fusion hcom
+      orders.emplace_back(cur_cnode);
+      auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
+      orders.emplace_back(send);
+    } else if (i == last_index) {
+      // last fusion hcom
+      auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
+      orders.emplace_back(recv);
+      orders.emplace_back(cur_cnode);
+      cur_event_id++;
+    } else {
+      auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
+      orders.emplace_back(recv);
+      cur_event_id++;
+      orders.emplace_back(cur_cnode);
+      auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
+      orders.emplace_back(send);
+    }
+
+    pre_hcom_stream_id = cur_hcom_stream_id;
+  }
+
+  std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders));
+  graph_ptr->set_execution_order(orders);
+  total_event_num_ = cur_event_id;
+  MS_LOG(INFO) << "after insert between allreduce, total event nums[" << total_event_num_ << "]";
+  MS_LOG(INFO) << "end";
+}
+
 void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<session::KernelGraph> &graph_ptr) {
   MS_LOG(INFO) << "start";
   MS_EXCEPTION_IF_NULL(graph_ptr);
@@ -324,6 +392,9 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<session::KernelGraph> &graph_ptr) {
   graph_ptr->set_execution_order(cnodes);
   total_event_num_ = cur_event_id;
   MS_LOG(INFO) << "after insert send/recv for hcom parallel, total event nums[" << total_event_num_ << "]";
+
+  // Insert Send/Recv between Hcom(such as:AllReduce1 Send1 Common Recv1 AllReduce2)
+  InsertSendRecvForDiffHcom(graph_ptr);
   MS_LOG(INFO) << "end";
 }
diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h
index 7728e61fb0..4bb55a3d21 100644
--- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h
+++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h
@@ -95,6 +95,7 @@ class AscendStreamAssign {
   void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
   void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
   void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
+  void InsertSendRecvForDiffHcom(const shared_ptr<session::KernelGraph> &graph_ptr);
   void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
   void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);

From cc39577c81015d2ff4e64937280d2add09661251 Mon Sep 17 00:00:00 2001
From: yanghaoran
Date: Wed, 17 Jun 2020 20:14:22 +0800
Subject: [PATCH 2/2] Synchronize Ascend software suite 17 Jun 2020

---
 graphengine                          |  2 +-
 mindspore/ops/_op_impl/tbe/matmul.py | 24 +++++++++++++++++-------
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/graphengine b/graphengine
index 45ca7863ac..1350673d51 160000
--- a/graphengine
+++ b/graphengine
@@ -1 +1 @@
-Subproject commit 45ca7863ac6410c8e2f83168481ddc6b43bcea33
+Subproject commit 1350673d51b3f8535bc217a7780e6a0b52ff9a41

diff --git a/mindspore/ops/_op_impl/tbe/matmul.py b/mindspore/ops/_op_impl/tbe/matmul.py
index c29378f721..e773191ae8 100644
--- a/mindspore/ops/_op_impl/tbe/matmul.py
+++ b/mindspore/ops/_op_impl/tbe/matmul.py
@@ -23,16 +23,26 @@ matmul_op_info = TBERegOp("MatMul") \
     .compute_cost(10) \
     .kernel_name("matmul") \
     .partial_flag(True) \
-    .attr("transpose_a", "required", "bool", "all") \
-    .attr("transpose_b", "required", "bool", "all") \
+    .attr("transpose_x1", "required", "bool", "all") \
+    .attr("transpose_x2", "required", "bool", "all") \
+    .attr("offset_x", "optional", "int", "all") \
     .input(0, "x1", False, "required", "all") \
     .input(1, "x2", False, "required", "all") \
-    .input(2, "x3", False, "optional", "all") \
+    .input(2, "bias", False, "optional", "all") \
+    .input(3, "offset_w", False, "optional", "all") \
     .output(0, "y", False, "required", "all") \
-    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.F32_FracNZ) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default,
+                  DataType.I32_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default,
+                  DataType.F16_FracNZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default,
+                  DataType.F32_FracNZ) \
+    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I8_Default,
+                  DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I8_Default,
+                  DataType.F32_Default) \
+    .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, DataType.I8_Default,
+                  DataType.I32_NHWC) \
     .get_op_info()
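
For illustration, a minimal standalone sketch of the Send/Recv event pairing that InsertSendRecvForDiffHcom (patch 1) applies between fusion hcom ops on different streams; the Op struct, stream ids, and example execution order below are hypothetical stand-ins for MindSpore's CNodePtr/AnfAlgo/KernelGraph machinery, not the actual implementation:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for a node in the execution order (the patch works on
// CNodePtr and queries the stream via AnfAlgo::GetStreamId).
struct Op {
  std::string name;
  uint32_t stream_id;
  bool is_hcom;
};

int main() {
  // Hypothetical execution order: two fused AllReduce ops on different hcom
  // streams with an unrelated compute op between them.
  std::vector<Op> exec_order = {
      {"Cast", 0, false}, {"AllReduce1", 3, true}, {"MatMul", 0, false}, {"AllReduce2", 4, true}};

  std::vector<size_t> hcom_index;
  for (size_t i = 0; i < exec_order.size(); ++i) {
    if (exec_order[i].is_hcom) hcom_index.push_back(i);
  }
  if (hcom_index.size() < 2) return 0;  // fewer than 2 hcom ops: nothing to pair

  const size_t first = hcom_index.front();
  const size_t last = hcom_index.back();
  uint32_t event_id = 0;      // the patch starts from total_event_num_
  uint32_t pre_stream = UINT32_MAX;
  std::vector<Op> orders(exec_order.begin(), exec_order.begin() + first);

  for (size_t i = first; i <= last; ++i) {
    const Op &cur = exec_order[i];
    if (!cur.is_hcom || cur.stream_id == pre_stream) {
      orders.push_back(cur);  // non-hcom op, or hcom on the same stream: keep as is
      continue;
    }
    if (i == first) {
      orders.push_back(cur);  // first hcom: run, then signal the next one
      orders.push_back({"Send" + std::to_string(event_id), cur.stream_id, false});
    } else if (i == last) {
      orders.push_back({"Recv" + std::to_string(event_id), cur.stream_id, false});
      orders.push_back(cur);  // last hcom: wait for the previous one, then run
      ++event_id;
    } else {
      orders.push_back({"Recv" + std::to_string(event_id), cur.stream_id, false});
      ++event_id;             // middle hcom: wait, run, then signal with a new event
      orders.push_back(cur);
      orders.push_back({"Send" + std::to_string(event_id), cur.stream_id, false});
    }
    pre_stream = cur.stream_id;
  }
  orders.insert(orders.end(), exec_order.begin() + last + 1, exec_order.end());

  // Prints: Cast AllReduce1 Send0 MatMul Recv0 AllReduce2
  for (const Op &op : orders) std::cout << op.name << ' ';
  std::cout << '\n';
  return 0;
}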