|
|
@@ -109,6 +109,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> |
|
|
return; |
|
|
return; |
|
|
} |
|
|
} |
|
|
MS_EXCEPTION_IF_NULL(kernel_graph_ptr); |
|
|
MS_EXCEPTION_IF_NULL(kernel_graph_ptr); |
|
|
|
|
|
bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX; |
|
|
ReorderGetNext(kernel_graph_ptr); |
|
|
ReorderGetNext(kernel_graph_ptr); |
|
|
std::map<std::string, mindspore::ParameterPtr> switch_loop_input; |
|
|
std::map<std::string, mindspore::ParameterPtr> switch_loop_input; |
|
|
CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); |
|
|
CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); |
|
|
@@ -129,12 +130,17 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
auto orders = kernel_graph_ptr->execution_order(); |
|
|
|
|
|
|
|
|
const std::vector<CNodePtr> &orders = kernel_graph_ptr->execution_order(); |
|
|
if (orders.empty()) { |
|
|
if (orders.empty()) { |
|
|
MS_LOG(EXCEPTION) << "graph execution order is empty"; |
|
|
MS_LOG(EXCEPTION) << "graph execution order is empty"; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<CNodePtr> exec_order; |
|
|
std::vector<CNodePtr> exec_order; |
|
|
|
|
|
std::vector<uint32_t> getnext_active_streams; |
|
|
|
|
|
std::vector<uint32_t> fpbp_active_streams; |
|
|
|
|
|
CNodePtr getnext_cnode; |
|
|
|
|
|
uint32_t eos_done_event_id = UINT32_MAX; |
|
|
|
|
|
|
|
|
// getnext loop process |
|
|
// getnext loop process |
|
|
// getnext loop stream switch op |
|
|
// getnext loop stream switch op |
|
|
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); |
|
|
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); |
|
|
@@ -151,6 +157,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> |
|
|
exec_order.push_back(node); |
|
|
exec_order.push_back(node); |
|
|
AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); |
|
|
AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); |
|
|
if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { |
|
|
if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { |
|
|
|
|
|
getnext_cnode = node; |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -158,11 +165,52 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> |
|
|
// update getnext loop stream switch true_branch_stream attr |
|
|
// update getnext loop stream switch true_branch_stream attr |
|
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app); |
|
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app); |
|
|
|
|
|
|
|
|
// getnext loop send |
|
|
|
|
|
uint32_t getnext_event_id = resource_manager.ApplyNewEvent(); |
|
|
|
|
|
CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, getnext_event_id); |
|
|
|
|
|
AnfAlgo::SetStreamId(getnext_stream_id, send.get()); |
|
|
|
|
|
exec_order.push_back(send); |
|
|
|
|
|
|
|
|
// getnext loop fpbp start send |
|
|
|
|
|
uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent(); |
|
|
|
|
|
CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id); |
|
|
|
|
|
AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get()); |
|
|
|
|
|
exec_order.push_back(fpbp_start_send); |
|
|
|
|
|
|
|
|
|
|
|
if (eos_mode) { |
|
|
|
|
|
// getnext loop eos start send |
|
|
|
|
|
uint32_t eos_start_event_id = resource_manager.ApplyNewEvent(); |
|
|
|
|
|
CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id); |
|
|
|
|
|
AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get()); |
|
|
|
|
|
exec_order.push_back(eos_start_send); |
|
|
|
|
|
|
|
|
|
|
|
// End Of Sequence loop process |
|
|
|
|
|
// eos loop stream switch |
|
|
|
|
|
CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(eos_switch_app); |
|
|
|
|
|
uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); |
|
|
|
|
|
AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get()); |
|
|
|
|
|
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), eos_switch_app); |
|
|
|
|
|
exec_order.push_back(eos_switch_app); |
|
|
|
|
|
|
|
|
|
|
|
// eos loop eos start recv |
|
|
|
|
|
CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id); |
|
|
|
|
|
uint32_t eos_stream_id = resource_manager.ApplyNewStream(); |
|
|
|
|
|
AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get()); |
|
|
|
|
|
exec_order.push_back(eos_start_recv); |
|
|
|
|
|
|
|
|
|
|
|
// update eos loop stream switch true_branch_stream attr |
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(eos_stream_id), eos_switch_app); |
|
|
|
|
|
|
|
|
|
|
|
// EndOfSequence op |
|
|
|
|
|
CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(end_of_sequence_op); |
|
|
|
|
|
AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get()); |
|
|
|
|
|
exec_order.push_back(end_of_sequence_op); |
|
|
|
|
|
|
|
|
|
|
|
// eos loop eos done send |
|
|
|
|
|
eos_done_event_id = resource_manager.ApplyNewEvent(); |
|
|
|
|
|
CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, eos_done_event_id); |
|
|
|
|
|
AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get()); |
|
|
|
|
|
exec_order.push_back(eos_done_send); |
|
|
|
|
|
|
|
|
|
|
|
// eos loop stream active |
|
|
|
|
|
fpbp_active_streams.push_back(eos_switch_stream_id); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// fpbp loop process |
|
|
// fpbp loop process |
|
|
// fpbp loop stream switch |
|
|
// fpbp loop stream switch |
|
|
@@ -173,11 +221,11 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> |
|
|
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app); |
|
|
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app); |
|
|
exec_order.push_back(fpbp_switch_app); |
|
|
exec_order.push_back(fpbp_switch_app); |
|
|
|
|
|
|
|
|
// fpbp loop recv |
|
|
|
|
|
CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, getnext_event_id); |
|
|
|
|
|
|
|
|
// fpbp loop fpbp start recv |
|
|
|
|
|
CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id); |
|
|
uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); |
|
|
uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); |
|
|
AnfAlgo::SetStreamId(fpbp_stream_id, recv.get()); |
|
|
|
|
|
exec_order.push_back(recv); |
|
|
|
|
|
|
|
|
AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get()); |
|
|
|
|
|
exec_order.push_back(fpbp_start_recv); |
|
|
|
|
|
|
|
|
// update fpbp loop stream switch true_branch_stream attr |
|
|
// update fpbp loop stream switch true_branch_stream attr |
|
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(fpbp_stream_id), fpbp_switch_app); |
|
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(fpbp_stream_id), fpbp_switch_app); |
|
|
@@ -190,40 +238,41 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> |
|
|
|
|
|
|
|
|
// fpbp memcpy |
|
|
// fpbp memcpy |
|
|
std::vector<CNodePtr> memcpy_list; |
|
|
std::vector<CNodePtr> memcpy_list; |
|
|
std::vector<CNodePtr> before_list; |
|
|
|
|
|
std::vector<CNodePtr> after_list; |
|
|
|
|
|
bool first_memcpy_found = false; |
|
|
|
|
|
|
|
|
std::vector<CNodePtr> other_list; |
|
|
CNodePtr cur_cnode = nullptr; |
|
|
CNodePtr cur_cnode = nullptr; |
|
|
for (size_t idx = i + 1; idx < orders.size(); idx++) { |
|
|
for (size_t idx = i + 1; idx < orders.size(); idx++) { |
|
|
cur_cnode = orders[idx]; |
|
|
cur_cnode = orders[idx]; |
|
|
if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { |
|
|
if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { |
|
|
memcpy_list.emplace_back(cur_cnode); |
|
|
memcpy_list.emplace_back(cur_cnode); |
|
|
first_memcpy_found = true; |
|
|
|
|
|
} else if (first_memcpy_found) { |
|
|
|
|
|
after_list.emplace_back(cur_cnode); |
|
|
|
|
|
} else { |
|
|
} else { |
|
|
before_list.emplace_back(cur_cnode); |
|
|
|
|
|
|
|
|
other_list.emplace_back(cur_cnode); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
(void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order)); |
|
|
|
|
|
|
|
|
|
|
|
(void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); |
|
|
(void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); |
|
|
|
|
|
|
|
|
|
|
|
// fpbp loop eos done recv |
|
|
|
|
|
if (eos_mode) { |
|
|
|
|
|
CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id); |
|
|
|
|
|
AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get()); |
|
|
|
|
|
exec_order.push_back(eos_done_recv); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// stream active to activate getnext loop |
|
|
// stream active to activate getnext loop |
|
|
CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); |
|
|
CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); |
|
|
MS_EXCEPTION_IF_NULL(getnext_active_app); |
|
|
MS_EXCEPTION_IF_NULL(getnext_active_app); |
|
|
std::vector<uint32_t> getnext_active_streams = {getnext_switch_stream_id}; |
|
|
|
|
|
|
|
|
getnext_active_streams.push_back(getnext_switch_stream_id); |
|
|
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(getnext_active_streams), |
|
|
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(getnext_active_streams), |
|
|
getnext_active_app); |
|
|
getnext_active_app); |
|
|
exec_order.push_back(getnext_active_app); |
|
|
exec_order.push_back(getnext_active_app); |
|
|
|
|
|
|
|
|
// fpbp loop other ops |
|
|
// fpbp loop other ops |
|
|
(void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order)); |
|
|
|
|
|
|
|
|
(void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order)); |
|
|
|
|
|
|
|
|
// stream active to activate fpbp loop |
|
|
|
|
|
|
|
|
// stream active to activate fpbp loop and eos loop |
|
|
CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); |
|
|
CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); |
|
|
MS_EXCEPTION_IF_NULL(fpbp_active_app); |
|
|
MS_EXCEPTION_IF_NULL(fpbp_active_app); |
|
|
// specific deal for common ctrl stream policy |
|
|
|
|
|
std::vector<uint32_t> fpbp_active_streams = {fpbp_switch_stream_id}; |
|
|
|
|
|
|
|
|
fpbp_active_streams.push_back(fpbp_switch_stream_id); |
|
|
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(fpbp_active_streams), fpbp_active_app); |
|
|
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(fpbp_active_streams), fpbp_active_app); |
|
|
exec_order.push_back(fpbp_active_app); |
|
|
exec_order.push_back(fpbp_active_app); |
|
|
|
|
|
|
|
|
@@ -323,6 +372,55 @@ CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr<session::Kerne |
|
|
return stream_active_others_app; |
|
|
return stream_active_others_app; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
CNodePtr KernelAdjust::CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, |
|
|
|
|
|
const CNodePtr &node, size_t output_idx) { |
|
|
|
|
|
auto idx = NewValueNode(SizeToInt(output_idx)); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(idx); |
|
|
|
|
|
auto imm = std::make_shared<Int32Imm>(SizeToInt(output_idx)); |
|
|
|
|
|
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm); |
|
|
|
|
|
idx->set_abstract(abstract_scalar); |
|
|
|
|
|
CNodePtr tuple_getitem = kernel_graph_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
|
|
tuple_getitem->set_scope(node->scope()); |
|
|
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); |
|
|
|
|
|
TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); |
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); |
|
|
|
|
|
return tuple_getitem; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, |
|
|
|
|
|
const CNodePtr &getnext_cnode) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph_ptr); |
|
|
|
|
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; |
|
|
|
|
|
selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT}); |
|
|
|
|
|
selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8}); |
|
|
|
|
|
|
|
|
|
|
|
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); |
|
|
|
|
|
selected_kernel_builder.SetProcessor(kernel::Processor::AICPU); |
|
|
|
|
|
selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL); |
|
|
|
|
|
|
|
|
|
|
|
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); |
|
|
|
|
|
selected_kernel_builder.SetOutputsDeviceType({kNumberTypeUInt8}); |
|
|
|
|
|
// EndOfSequence |
|
|
|
|
|
auto end_of_sequence = std::make_shared<Primitive>(kEndOfSequence); |
|
|
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
|
|
inputs.push_back(NewValueNode(end_of_sequence)); |
|
|
|
|
|
// GetNext output 0 is EndOfSequence's input |
|
|
|
|
|
auto tuple_get_item = CreatTupleGetItemNode(kernel_graph_ptr, getnext_cnode, 0); |
|
|
|
|
|
inputs.push_back(tuple_get_item); |
|
|
|
|
|
CNodePtr end_of_sequence_node = kernel_graph_ptr->NewCNode(inputs); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(end_of_sequence_node); |
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), end_of_sequence_node.get()); |
|
|
|
|
|
std::vector<std::string> input_names = {"x"}; |
|
|
|
|
|
ValuePtr input_names_v = MakeValue(input_names); |
|
|
|
|
|
AnfAlgo::SetNodeAttr("input_names", input_names_v, end_of_sequence_node); |
|
|
|
|
|
std::vector<std::string> output_names = {"y"}; |
|
|
|
|
|
ValuePtr output_names_v = MakeValue(output_names); |
|
|
|
|
|
AnfAlgo::SetNodeAttr("output_names", output_names_v, end_of_sequence_node); |
|
|
|
|
|
end_of_sequence_node->set_abstract(tuple_get_item->abstract()); |
|
|
|
|
|
return end_of_sequence_node; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
CNodePtr KernelAdjust::CreateStreamAssignAddnOP( |
|
|
CNodePtr KernelAdjust::CreateStreamAssignAddnOP( |
|
|
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, |
|
|
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, |
|
|
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) { |
|
|
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) { |
|
|
|