
!2247 Synchronize Ascend software suite 17 Jun 2020

Merge pull request !2247 from yanghaoran/master
tags/v0.5.0-beta
mindspore-ci-bot, 5 years ago
commit efa61b061e
4 changed files with 90 additions and 8 deletions
  1. graphengine (+1, -1)
  2. mindspore/ccsrc/device/ascend/ascend_stream_assign.cc (+71, -0)
  3. mindspore/ccsrc/device/ascend/ascend_stream_assign.h (+1, -0)
  4. mindspore/ops/_op_impl/tbe/matmul.py (+17, -7)

graphengine (+1, -1)

@@ -1 +1 @@
-Subproject commit 45ca7863ac6410c8e2f83168481ddc6b43bcea33
+Subproject commit 1350673d51b3f8535bc217a7780e6a0b52ff9a41

mindspore/ccsrc/device/ascend/ascend_stream_assign.cc (+71, -0)

@@ -291,6 +291,74 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr<session::KernelG
  }
}


// Insert Send/Recv event pairs between consecutive fusion hcom nodes that run
// on different streams, so one collective cannot overtake another.
void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
  MS_LOG(INFO) << "start";
  MS_EXCEPTION_IF_NULL(graph_ptr);
  auto cnode_ptr_list = graph_ptr->execution_order();
  vector<uint32_t> fusion_hcom_index;
  vector<CNodePtr> orders;
  for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
    auto cur_cnode = cnode_ptr_list[i];
    if (IsHcom(cur_cnode)) {
      fusion_hcom_index.emplace_back(i);
    }
  }

  if (fusion_hcom_index.size() < 2) {
    MS_LOG(INFO) << "fusion hcom size is less than 2, no need to insert event between them";
    return;
  }

  uint32_t first_index = fusion_hcom_index[0];
  uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1];

  uint32_t cur_event_id = total_event_num_;
  uint32_t pre_hcom_stream_id = UINT32_MAX;
  std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders));
  for (size_t i = first_index; i <= last_index; i++) {
    auto cur_cnode = cnode_ptr_list[i];
    auto it = std::find(fusion_hcom_index.begin(), fusion_hcom_index.end(), i);
    if (it == fusion_hcom_index.end()) {
      orders.emplace_back(cur_cnode);
      continue;
    }

    auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode);
    if (cur_hcom_stream_id == pre_hcom_stream_id) {
      orders.emplace_back(cur_cnode);
      continue;
    }

    if (i == first_index) {
      // first fusion hcom: only a Send after it
      orders.emplace_back(cur_cnode);
      auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
      orders.emplace_back(send);
    } else if (i == last_index) {
      // last fusion hcom: only a Recv before it
      auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
      orders.emplace_back(recv);
      orders.emplace_back(cur_cnode);
      cur_event_id++;
    } else {
      // middle fusion hcom: Recv before it, Send after it
      auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
      orders.emplace_back(recv);
      cur_event_id++;
      orders.emplace_back(cur_cnode);
      auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
      orders.emplace_back(send);
    }

    pre_hcom_stream_id = cur_hcom_stream_id;
  }

  std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders));
  graph_ptr->set_execution_order(orders);
  total_event_num_ = cur_event_id;
  MS_LOG(INFO) << "after insert between allreduce, total event nums[" << total_event_num_ << "]";
  MS_LOG(INFO) << "end";
}

void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
  MS_LOG(INFO) << "start";
  MS_EXCEPTION_IF_NULL(graph_ptr);
@@ -324,6 +392,9 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspor
  graph_ptr->set_execution_order(cnodes);
  total_event_num_ = cur_event_id;
  MS_LOG(INFO) << "after insert send/recv for hcom parallel, total event nums[" << total_event_num_ << "]";

  // Insert Send/Recv between Hcom(such as:AllReduce1 Send1 Common Recv1 AllReduce2)
  InsertSendRecvForDiffHcom(graph_ptr);
  MS_LOG(INFO) << "end";
}
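
The effect of the new pass is easiest to see on a toy execution order. Below is a minimal Python sketch of the same ordering logic, for illustration only; it is not MindSpore code, and the names insert_send_recv_for_diff_hcom, is_hcom, and stream_of are hypothetical stand-ins for the C++ above.

# Minimal Python sketch (not MindSpore code) of InsertSendRecvForDiffHcom:
# thread a Send/Recv event pair between each pair of consecutive hcom nodes
# that sit on different streams, so the later collective waits for the earlier.
def insert_send_recv_for_diff_hcom(order, is_hcom, stream_of, first_event_id=0):
    """order: list of node names; is_hcom/stream_of: per-node lookups."""
    hcom_idx = [i for i, n in enumerate(order) if is_hcom(n)]
    if len(hcom_idx) < 2:
        return order, first_event_id  # nothing to synchronize

    out, event_id, pre_stream = [], first_event_id, None
    first, last = hcom_idx[0], hcom_idx[-1]
    out.extend(order[:first])
    for i in range(first, last + 1):
        node = order[i]
        if i not in hcom_idx or stream_of(node) == pre_stream:
            out.append(node)  # non-hcom node, or same stream: no event needed
            continue
        if i == first:                      # first hcom: Send after it
            out.append(node)
            out.append(f"Send{event_id}")
        elif i == last:                     # last hcom: Recv before it
            out.append(f"Recv{event_id}")
            out.append(node)
            event_id += 1
        else:                               # middle hcom: Recv before, Send after
            out.append(f"Recv{event_id}")
            event_id += 1
            out.append(node)
            out.append(f"Send{event_id}")
        pre_stream = stream_of(node)
    out.extend(order[last + 1:])
    return out, event_id

order = ["Common0", "AllReduce1", "Common1", "AllReduce2", "Common2"]
streams = {"AllReduce1": 3, "AllReduce2": 4}
new_order, _ = insert_send_recv_for_diff_hcom(
    order, lambda n: n.startswith("AllReduce"), streams.get)
print(new_order)
# ['Common0', 'AllReduce1', 'Send0', 'Common1', 'Recv0', 'AllReduce2', 'Common2']

The Send/Recv pair keeps AllReduce2 behind AllReduce1 even though they sit on different streams, which is the "AllReduce1 Send1 Common Recv1 AllReduce2" pattern named in the comment above.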




mindspore/ccsrc/device/ascend/ascend_stream_assign.h (+1, -0)

@@ -95,6 +95,7 @@ class AscendStreamAssign {
  void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
  void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  void InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr);
  void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
  void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);




mindspore/ops/_op_impl/tbe/matmul.py (+17, -7)

@@ -23,16 +23,26 @@ matmul_op_info = TBERegOp("MatMul") \
     .compute_cost(10) \
     .kernel_name("matmul") \
     .partial_flag(True) \
-    .attr("transpose_a", "required", "bool", "all") \
-    .attr("transpose_b", "required", "bool", "all") \
+    .attr("transpose_x1", "required", "bool", "all") \
+    .attr("transpose_x2", "required", "bool", "all") \
+    .attr("offset_x", "optional", "int", "all") \
     .input(0, "x1", False, "required", "all") \
     .input(1, "x2", False, "required", "all") \
-    .input(2, "x3", False, "optional", "all") \
+    .input(2, "bias", False, "optional", "all") \
+    .input(3, "offset_w", False, "optional", "all") \
     .output(0, "y", False, "required", "all") \
-    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.F32_FracNZ) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default,
+                  DataType.I32_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default,
+                  DataType.F16_FracNZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default,
+                  DataType.F32_FracNZ) \
+    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I8_Default,
+                  DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I8_Default,
+                  DataType.F32_Default) \
+    .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, DataType.I8_Default,
+                  DataType.I32_NHWC) \
     .get_op_info()
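
Each dtype_format call now takes five arguments because the registration declares four inputs (x1, x2, bias, offset_w) and one output (y), with offset_w pinned to int8 in every row. A minimal sketch of the resulting support table, assuming a hypothetical supports helper (not MindSpore API), with the rows copied from the diff above:

# Illustrative only: the dtype/format rows registered for MatMul after this
# patch, one (x1, x2, bias, offset_w, y) combination per entry.
SUPPORTED_ROWS = [
    ("I32_Default", "I32_Default", "I32_Default", "I8_Default", "I32_Default"),
    ("F16_FracNZ",  "F16_FracNZ",  "F16_Default", "I8_Default", "F16_FracNZ"),
    ("F16_FracNZ",  "F16_FracNZ",  "F32_Default", "I8_Default", "F32_FracNZ"),
    ("F32_NHWC",    "F32_NHWC",    "F32_NHWC",    "I8_Default", "F32_NHWC"),
    ("F32_Default", "F32_Default", "F32_Default", "I8_Default", "F32_Default"),
    ("I32_NHWC",    "I32_NHWC",    "I32_NHWC",    "I8_Default", "I32_NHWC"),
]

def supports(x1, x2, bias, offset_w, y):
    """Check whether a (dtype, format) tuple matches a registered row."""
    return (x1, x2, bias, offset_w, y) in SUPPORTED_ROWS

assert supports("F16_FracNZ", "F16_FracNZ", "F32_Default", "I8_Default", "F32_FracNZ")
assert not supports("F16_FracNZ", "F16_FracNZ", "F16_Default", "I8_Default", "F32_FracNZ")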





