Merge pull request !4610 from gukecai/reverttags/v0.7.0-beta
| @@ -1163,31 +1163,5 @@ bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) { | |||||
| auto input = node->input(kAnfPrimitiveIndex); | auto input = node->input(kAnfPrimitiveIndex); | ||||
| return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch); | return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch); | ||||
| } | } | ||||
| bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (AnfAlgo::GetKernelType(node) != AICPU_KERNEL) { | |||||
| return false; | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { | |||||
| MS_LOG(INFO) << "GetNext should not be independent node"; | |||||
| return false; | |||||
| } | |||||
| uint32_t input_nums = AnfAlgo::GetInputTensorNum(node); | |||||
| if (input_nums == 0) { | |||||
| return true; | |||||
| } | |||||
| auto inputs = node->inputs(); | |||||
| for (size_t i = 1; i < inputs.size(); i++) { | |||||
| if (!inputs[i]->isa<ValueNode>()) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -212,7 +212,6 @@ class AnfRuntimeAlgorithm { | |||||
| // get fix output precision from prev node, input_idx is the input index of current node related to prev node. | // get fix output precision from prev node, input_idx is the input index of current node related to prev node. | ||||
| static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); | static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); | ||||
| static bool IsCondControlKernel(const CNodePtr &node); | static bool IsCondControlKernel(const CNodePtr &node); | ||||
| static bool IsIndependentNode(const CNodePtr &node); | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | using AnfAlgo = session::AnfRuntimeAlgorithm; | ||||
| @@ -180,32 +180,20 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto | |||||
| if (inputs_params == nullptr) { | if (inputs_params == nullptr) { | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| if (inputs_params->size() < 3) { | |||||
| if (inputs_params->size() < 2) { | |||||
| MS_LOG(EXCEPTION) << "Illegal inputs_params size"; | MS_LOG(EXCEPTION) << "Illegal inputs_params size"; | ||||
| } | } | ||||
| // update current loop tensor to 0 per iterator | |||||
| auto cur_loop_tensor = (*inputs_params)[0]; | |||||
| MS_EXCEPTION_IF_NULL(cur_loop_tensor); | |||||
| auto *cur_val = static_cast<int32_t *>(cur_loop_tensor->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(cur_val); | |||||
| *cur_val = 0; | |||||
| cur_loop_tensor->set_dirty(true); | |||||
| auto tensor = (*inputs_params)[0]; | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| auto *val = static_cast<int32_t *>(tensor->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(val); | |||||
| *val = 0; | |||||
| tensor->set_dirty(true); | |||||
| // set loop_count to zero | // set loop_count to zero | ||||
| MS_EXCEPTION_IF_NULL(inputs); | MS_EXCEPTION_IF_NULL(inputs); | ||||
| inputs->push_back(cur_loop_tensor); | |||||
| inputs->push_back(tensor); | |||||
| // update next loop tensor to 0 per iterator | |||||
| auto next_loop_tensor = (*inputs_params)[1]; | |||||
| MS_EXCEPTION_IF_NULL(next_loop_tensor); | |||||
| auto *next_val = static_cast<int32_t *>(next_loop_tensor->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(next_val); | |||||
| *next_val = 0; | |||||
| next_loop_tensor->set_dirty(true); | |||||
| // set loop_count to zero | |||||
| MS_EXCEPTION_IF_NULL(inputs); | |||||
| inputs->push_back(next_loop_tensor); | |||||
| auto epoch_tensor = (*inputs_params)[2]; | |||||
| auto epoch_tensor = (*inputs_params)[1]; | |||||
| MS_EXCEPTION_IF_NULL(epoch_tensor); | MS_EXCEPTION_IF_NULL(epoch_tensor); | ||||
| auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c()); | auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c()); | ||||
| MS_EXCEPTION_IF_NULL(epoch_val); | MS_EXCEPTION_IF_NULL(epoch_val); | ||||
| @@ -942,7 +930,7 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor | |||||
| void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | ||||
| const std::vector<tensor::TensorPtr> &inputs_const) const { | const std::vector<tensor::TensorPtr> &inputs_const) const { | ||||
| std::vector<tensor::TensorPtr> inputs(inputs_const); | std::vector<tensor::TensorPtr> inputs(inputs_const); | ||||
| size_t input_ctrl_size = 3; | |||||
| size_t input_ctrl_size = 2; | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| if (kernel_graph->input_ctrl_tensors()) { | if (kernel_graph->input_ctrl_tensors()) { | ||||
| input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); | input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); | ||||
| @@ -952,7 +940,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||||
| auto params = AnfAlgo::GetAllOutput(input_node); | auto params = AnfAlgo::GetAllOutput(input_node); | ||||
| std::copy(params.begin(), params.end(), std::back_inserter(input_nodes)); | std::copy(params.begin(), params.end(), std::back_inserter(input_nodes)); | ||||
| } | } | ||||
| if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size()) { | |||||
| if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) { | |||||
| MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() | MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() | ||||
| << ", input_ctrl_size:" << input_ctrl_size; | << ", input_ctrl_size:" << input_ctrl_size; | ||||
| } | } | ||||
| @@ -42,9 +42,6 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) | |||||
| InsertStreamActive(graph_ptr); | InsertStreamActive(graph_ptr); | ||||
| InsertEventForHcomParallel(graph_ptr); | InsertEventForHcomParallel(graph_ptr); | ||||
| InsertEventForIndependentParallel(graph_ptr); | InsertEventForIndependentParallel(graph_ptr); | ||||
| GetIndependentMaxTarget(graph_ptr); | |||||
| InsertCtrlForIndependentParallel(graph_ptr); | |||||
| GetNeedActiveStreams(graph_ptr); | GetNeedActiveStreams(graph_ptr); | ||||
| graph_ptr->PrintGraphExecuteOrder(); | graph_ptr->PrintGraphExecuteOrder(); | ||||
| CheckResourceAssign(graph_ptr); | CheckResourceAssign(graph_ptr); | ||||
| @@ -69,7 +66,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> | |||||
| for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | ||||
| auto cur_cnode_ptr = cnode_ptr_list[i]; | auto cur_cnode_ptr = cnode_ptr_list[i]; | ||||
| MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | ||||
| if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) { | |||||
| if (IsIndependentNode(cur_cnode_ptr)) { | |||||
| independents.emplace_back(cur_cnode_ptr); | independents.emplace_back(cur_cnode_ptr); | ||||
| } else { | } else { | ||||
| others.emplace_back(cur_cnode_ptr); | others.emplace_back(cur_cnode_ptr); | ||||
| @@ -136,7 +133,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) { | |||||
| if (IsIndependentNode(cur_cnode_ptr)) { | |||||
| exit_independent = true; | exit_independent = true; | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -168,7 +165,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra | |||||
| if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { | if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) { | |||||
| if (IsIndependentNode(cur_cnode_ptr)) { | |||||
| AssignIndependentStreamId(cur_cnode_ptr); | AssignIndependentStreamId(cur_cnode_ptr); | ||||
| } | } | ||||
| } | } | ||||
| @@ -245,6 +242,33 @@ void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr | |||||
| } | } | ||||
| } | } | ||||
| bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(node_ptr); | |||||
| if (AnfAlgo::GetKernelType(node_ptr) != AICPU_KERNEL) { | |||||
| return false; | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) { | |||||
| MS_LOG(INFO) << "GetNext should not be independent node"; | |||||
| return false; | |||||
| } | |||||
| uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr); | |||||
| if (input_nums == 0) { | |||||
| MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero"; | |||||
| return true; | |||||
| } | |||||
| auto inputs = node_ptr->inputs(); | |||||
| for (size_t i = 1; i < inputs.size(); i++) { | |||||
| if (!inputs[i]->isa<ValueNode>()) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node"; | |||||
| return true; | |||||
| } | |||||
| // section 3: | // section 3: | ||||
| void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr) { | void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr) { | ||||
| MS_LOG(INFO) << "Start"; | MS_LOG(INFO) << "Start"; | ||||
| @@ -269,11 +293,13 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph | |||||
| CNodePtr pre_cnode_ptr = nullptr; | CNodePtr pre_cnode_ptr = nullptr; | ||||
| uint32_t pre_stream_id = UINT32_MAX; | uint32_t pre_stream_id = UINT32_MAX; | ||||
| bool independent_flag = !(independent_stream_map_.empty()); | |||||
| bool hcom_flag = !(hcom_stream_map_.empty()); | |||||
| auto cnode_ptr_list = graph_ptr->execution_order(); | auto cnode_ptr_list = graph_ptr->execution_order(); | ||||
| for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | ||||
| cur_cnode_ptr = cnode_ptr_list[i]; | cur_cnode_ptr = cnode_ptr_list[i]; | ||||
| MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | ||||
| if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) { | |||||
| if (IsIndependentNode(cur_cnode_ptr)) { | |||||
| update_cnode_list.emplace_back(cur_cnode_ptr); | update_cnode_list.emplace_back(cur_cnode_ptr); | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -296,7 +322,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph | |||||
| update_cnode_list.emplace_back(active_ptr); | update_cnode_list.emplace_back(active_ptr); | ||||
| } | } | ||||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { | |||||
| if ((independent_flag || hcom_flag) && (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName)) { | |||||
| MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel"; | MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel"; | ||||
| UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list); | UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list); | ||||
| } else { | } else { | ||||
| @@ -320,10 +346,8 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph | |||||
| uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); | uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); | ||||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { | if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { | ||||
| if (AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, cur_cnode_ptr)) { | |||||
| auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream); | |||||
| processed_streams_.emplace(true_stream_id); | |||||
| } | |||||
| auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream); | |||||
| processed_streams_.emplace(true_stream_id); | |||||
| if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { | if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { | ||||
| continue; | continue; | ||||
| @@ -341,78 +365,46 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph | |||||
| void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr, | void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr, | ||||
| vector<CNodePtr> *orders) { | vector<CNodePtr> *orders) { | ||||
| orders->emplace_back(switch_ptr); | |||||
| if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) { | if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) { | ||||
| orders->emplace_back(switch_ptr); | |||||
| return; | return; | ||||
| } | } | ||||
| auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst); | auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst); | ||||
| if (!need_active) { | if (!need_active) { | ||||
| orders->emplace_back(switch_ptr); | |||||
| return; | return; | ||||
| } | } | ||||
| if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, switch_ptr)) { | |||||
| orders->emplace_back(switch_ptr); | |||||
| return; | |||||
| MS_EXCEPTION_IF_NULL(switch_ptr); | |||||
| auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream); | |||||
| MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) | |||||
| << "; active stream id:" << true_stream_id; | |||||
| CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); | |||||
| AnfAlgo::SetStreamId(true_stream_id, active_ptr.get()); | |||||
| vector<uint32_t> active_ids; | |||||
| // active indepdent stream | |||||
| for (const auto &item : independent_stream_map_) { | |||||
| active_ids.emplace_back(item.first); | |||||
| } | } | ||||
| auto kind = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrStreamSwitchKind); | |||||
| if (kind == kEosStreamSwitch || kind == kGetNextStreamSwitch) { | |||||
| orders->emplace_back(switch_ptr); | |||||
| return; | |||||
| // active hcom stream | |||||
| for (const auto &item : hcom_stream_map_) { | |||||
| active_ids.emplace_back(item.first); | |||||
| } | } | ||||
| AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_ids), active_ptr); | |||||
| if (kind == kIndependentStreamSwitch) { | |||||
| bool independent_empty = independent_stream_map_.empty(); | |||||
| // if indepdent empty: delete independent streamswitch | |||||
| if (!independent_empty) { | |||||
| for (const auto &item : independent_stream_map_) { | |||||
| // first independetn stream id is minimum and order by std map; | |||||
| auto first_independent_stream = item.first; | |||||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_independent_stream), switch_ptr); | |||||
| orders->emplace_back(switch_ptr); | |||||
| break; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "independent stream switch exit, but independent stream is empty"; | |||||
| } | |||||
| // update processed stream | |||||
| independent_stream_activated_ = true; | |||||
| for (const auto &item : independent_stream_map_) { | |||||
| processed_streams_.emplace(item.first); | |||||
| } | |||||
| return; | |||||
| // update processed stream | |||||
| independent_stream_activated_ = true; | |||||
| for (const auto &item : independent_stream_map_) { | |||||
| processed_streams_.emplace(item.first); | |||||
| } | } | ||||
| if (kind == kFpBpStreamSwitch) { | |||||
| bool hcom_empty = hcom_stream_map_.empty(); | |||||
| if (hcom_empty) { | |||||
| orders->emplace_back(switch_ptr); | |||||
| return; | |||||
| } | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, switch_ptr)) { | |||||
| orders->emplace_back(switch_ptr); | |||||
| MS_LOG(WARNING) << "FpBp StreamSwitch has no true branch attr"; | |||||
| return; | |||||
| } | |||||
| auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream); | |||||
| MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) | |||||
| << "; active stream id:" << true_stream_id; | |||||
| CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); | |||||
| AnfAlgo::SetStreamId(true_stream_id, active_ptr.get()); | |||||
| vector<uint32_t> active_ids; | |||||
| // active hcom stream | |||||
| for (const auto &item : hcom_stream_map_) { | |||||
| active_ids.emplace_back(item.first); | |||||
| } | |||||
| AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_ids), active_ptr); | |||||
| hcom_stream_activated_ = true; | |||||
| for (const auto &item : hcom_stream_map_) { | |||||
| processed_streams_.emplace(item.first); | |||||
| } | |||||
| orders->emplace_back(switch_ptr); | |||||
| orders->emplace_back(active_ptr); | |||||
| hcom_stream_activated_ = true; | |||||
| for (const auto &item : hcom_stream_map_) { | |||||
| processed_streams_.emplace(item.first); | |||||
| } | } | ||||
| orders->emplace_back(active_ptr); | |||||
| } | } | ||||
| bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { | bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { | ||||
| @@ -640,7 +632,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG | |||||
| auto it = cnodes.begin(); | auto it = cnodes.begin(); | ||||
| while (it != cnodes.end()) { | while (it != cnodes.end()) { | ||||
| MS_EXCEPTION_IF_NULL(*it); | MS_EXCEPTION_IF_NULL(*it); | ||||
| if (AnfAlgo::IsIndependentNode(*it)) { | |||||
| if (IsIndependentNode(*it)) { | |||||
| MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]"; | MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]"; | ||||
| CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); | CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); | ||||
| it = cnodes.insert(it + 1, send_cnode_ptr); | it = cnodes.insert(it + 1, send_cnode_ptr); | ||||
| @@ -668,129 +660,6 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG | |||||
| MS_LOG(INFO) << "End"; | MS_LOG(INFO) << "End"; | ||||
| } | } | ||||
| void AscendStreamAssign::GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr) { | |||||
| MS_LOG(INFO) << "Start"; | |||||
| auto cnode_ptr_list = graph_ptr->execution_order(); | |||||
| for (size_t i = 0; i < cnode_ptr_list.size(); i++) { | |||||
| auto cur_node = cnode_ptr_list[i]; | |||||
| auto key = cur_node.get(); | |||||
| if (!AnfAlgo::IsIndependentNode(cur_node)) { | |||||
| continue; | |||||
| } | |||||
| bool flag = false; | |||||
| for (size_t j = cnode_ptr_list.size() - 1; j > i; j--) { | |||||
| auto target_node = cnode_ptr_list[j]; | |||||
| auto inputs = target_node->inputs(); | |||||
| for (size_t m = 1; m < inputs.size(); m++) { | |||||
| auto input = inputs[m]; | |||||
| if (opt::IsNopNode(input)) { | |||||
| CNodePtr cnode = input->cast<CNodePtr>(); | |||||
| auto new_inputs = cnode->inputs(); | |||||
| for (size_t k = 1; k < new_inputs.size(); k++) { | |||||
| auto new_real_input = AnfAlgo::VisitKernel(new_inputs[k], 0); | |||||
| if (key == new_real_input.first.get()) { | |||||
| MS_LOG(INFO) << "Nop node find max target op:" << AnfAlgo::GetCNodeName(cur_node); | |||||
| independent_targets_.emplace(target_node.get()); | |||||
| flag = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| auto real_input = AnfAlgo::VisitKernel(input, 0); | |||||
| if (key == real_input.first.get()) { | |||||
| MS_LOG(INFO) << "Find max target op:" << AnfAlgo::GetCNodeName(cur_node); | |||||
| independent_targets_.emplace(target_node.get()); | |||||
| flag = true; | |||||
| } | |||||
| } | |||||
| if (flag) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| MS_LOG(INFO) << "End"; | |||||
| } | |||||
| uint32_t AscendStreamAssign::GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key) { | |||||
| auto &exe_orders = graph_ptr->execution_order(); | |||||
| for (uint32_t i = 0; i < exe_orders.size(); i++) { | |||||
| CNodeKey node_key = exe_orders[i].get(); | |||||
| if (node_key == key) { | |||||
| return i; | |||||
| } | |||||
| } | |||||
| return UINT32_MAX; | |||||
| } | |||||
| uint32_t AscendStreamAssign::GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr) { | |||||
| if (independent_targets_.empty()) { | |||||
| return UINT32_MAX; | |||||
| } | |||||
| std::set<uint32_t> indexs; | |||||
| for (const auto &key : independent_targets_) { | |||||
| auto index = GetIndexByKey(graph_ptr, key); | |||||
| if (index == UINT32_MAX) { | |||||
| MS_LOG(EXCEPTION) << "graph has no correspond key"; | |||||
| } | |||||
| indexs.emplace(index); | |||||
| } | |||||
| return *(std::max_element(indexs.begin(), indexs.end())); | |||||
| } | |||||
| uint32_t AscendStreamAssign::GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr) { | |||||
| auto &exe_orders = graph_ptr->execution_order(); | |||||
| for (const auto &item : exe_orders) { | |||||
| if (AnfAlgo::GetCNodeName(item) == kStreamSwitchOpName) { | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, item)) { | |||||
| continue; | |||||
| } | |||||
| auto kind = AnfAlgo::GetNodeAttr<uint32_t>(item, kAttrStreamSwitchKind); | |||||
| if (kind == kIndependentStreamSwitch) { | |||||
| return AnfAlgo::GetStreamId(item); | |||||
| } | |||||
| } | |||||
| } | |||||
| return kInvalidStreamId; | |||||
| } | |||||
| void AscendStreamAssign::InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr) { | |||||
| if (independent_targets_.empty()) { | |||||
| return; | |||||
| } | |||||
| uint32_t independent_switch_stream = GetIndependentStreamSwitchStreamId(graph_ptr); | |||||
| if (independent_switch_stream == kInvalidStreamId) { | |||||
| return; | |||||
| } | |||||
| auto max_index = GetMaxIndexTarget(graph_ptr); | |||||
| auto &exe_orders = graph_ptr->execution_order(); | |||||
| if (max_index >= exe_orders.size()) { | |||||
| MS_LOG(EXCEPTION) << "max target index:" << max_index << " is greater than graph orders size:" << exe_orders.size(); | |||||
| } | |||||
| auto max_node_stream = AnfAlgo::GetStreamId(exe_orders[max_index]); | |||||
| CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); | |||||
| // 1.set stream id | |||||
| AnfAlgo::SetStreamId(max_node_stream, active_ptr.get()); | |||||
| // 2.set active stream ids | |||||
| std::vector<uint32_t> active_index_list{independent_switch_stream}; | |||||
| AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr); | |||||
| std::vector<CNodePtr> update_cnode_list; | |||||
| std::copy(exe_orders.begin(), exe_orders.begin() + max_index + 1, std::back_inserter(update_cnode_list)); | |||||
| update_cnode_list.emplace_back(active_ptr); | |||||
| std::copy(exe_orders.begin() + max_index + 1, exe_orders.end(), std::back_inserter(update_cnode_list)); | |||||
| graph_ptr->set_execution_order(update_cnode_list); | |||||
| } | |||||
| // section7 | // section7 | ||||
| void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr) { | void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr) { | ||||
| CNodePtr cur_cnode_ptr = nullptr; | CNodePtr cur_cnode_ptr = nullptr; | ||||
| @@ -1048,7 +917,6 @@ void AscendStreamAssign::Reset() { | |||||
| stream_groups_.clear(); | stream_groups_.clear(); | ||||
| stream_relations_.clear(); | stream_relations_.clear(); | ||||
| event_map_.clear(); | event_map_.clear(); | ||||
| independent_targets_.clear(); | |||||
| } | } | ||||
| // section 10 | // section 10 | ||||
| @@ -39,7 +39,6 @@ using std::shared_ptr; | |||||
| using std::unordered_map; | using std::unordered_map; | ||||
| using std::unordered_set; | using std::unordered_set; | ||||
| using std::vector; | using std::vector; | ||||
| using CNodeKey = void *; | |||||
| const uint32_t kInvalidStreamId = UINT32_MAX; | const uint32_t kInvalidStreamId = UINT32_MAX; | ||||
| const uint32_t kInvalidEventId = UINT32_MAX; | const uint32_t kInvalidEventId = UINT32_MAX; | ||||
| class AscendResourceMng { | class AscendResourceMng { | ||||
| @@ -109,6 +108,8 @@ class AscendStreamAssign { | |||||
| void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr); | void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| void GetHcomStreams(std::vector<uint32_t> *streams); | void GetHcomStreams(std::vector<uint32_t> *streams); | ||||
| void GetWaitStreams(vector<uint32_t> *wait_active_stream_list); | void GetWaitStreams(vector<uint32_t> *wait_active_stream_list); | ||||
| CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); | |||||
| CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); | |||||
| const std::vector<std::vector<uint32_t>> &get_stream_group() const { return stream_groups_; } | const std::vector<std::vector<uint32_t>> &get_stream_group() const { return stream_groups_; } | ||||
| const std::map<CNodePtr, CNodePtr> &get_event_map() const { return event_map_; } | const std::map<CNodePtr, CNodePtr> &get_event_map() const { return event_map_; } | ||||
| @@ -116,8 +117,6 @@ class AscendStreamAssign { | |||||
| AscendStreamAssign() = default; | AscendStreamAssign() = default; | ||||
| ~AscendStreamAssign() = default; | ~AscendStreamAssign() = default; | ||||
| void Reset(); | void Reset(); | ||||
| CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); | |||||
| CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); | |||||
| void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr); | void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr); | void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr); | void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| @@ -131,7 +130,6 @@ class AscendStreamAssign { | |||||
| void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr, | void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr, | ||||
| vector<CNodePtr> *orders); | vector<CNodePtr> *orders); | ||||
| void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr); | void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| void InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr); | void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr); | void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr); | void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| @@ -143,10 +141,6 @@ class AscendStreamAssign { | |||||
| void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr); | void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr); | void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr); | void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| uint32_t GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| uint32_t GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key); | |||||
| uint32_t GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| bool IsTaskSink(); | bool IsTaskSink(); | ||||
| bool IsHcom(const CNodePtr &cur_cnode_ptr); | bool IsHcom(const CNodePtr &cur_cnode_ptr); | ||||
| @@ -177,7 +171,6 @@ class AscendStreamAssign { | |||||
| std::map<uint32_t, uint32_t> common_stream_map_{}; | std::map<uint32_t, uint32_t> common_stream_map_{}; | ||||
| std::set<uint32_t> processed_streams_{}; | std::set<uint32_t> processed_streams_{}; | ||||
| std::vector<uint32_t> need_first_active_streams_{}; | std::vector<uint32_t> need_first_active_streams_{}; | ||||
| std::set<CNodeKey> independent_targets_; | |||||
| // attr for memory copy reuse | // attr for memory copy reuse | ||||
| std::map<uint32_t, std::vector<uint32_t>> stream_relations_{}; | std::map<uint32_t, std::vector<uint32_t>> stream_relations_{}; | ||||
| @@ -34,8 +34,8 @@ static constexpr uint32_t kTupleTaskId = 0; | |||||
| static constexpr uint32_t kTupleStreamId = 1; | static constexpr uint32_t kTupleStreamId = 1; | ||||
| static constexpr uint32_t kTupleArgs = 2; | static constexpr uint32_t kTupleArgs = 2; | ||||
| static constexpr uint32_t kCurrentStepTensorIndex = 0; | static constexpr uint32_t kCurrentStepTensorIndex = 0; | ||||
| static constexpr uint32_t kCurrentEpochTensorIndex = 2; | |||||
| static constexpr uint32_t kStepsPerEpochTensorIndex = 3; | |||||
| static constexpr uint32_t kCurrentEpochTensorIndex = 1; | |||||
| static constexpr uint32_t kStepsPerEpochTensorIndex = 2; | |||||
| static constexpr uint64_t kOpDebugShape = 2048; | static constexpr uint64_t kOpDebugShape = 2048; | ||||
| static constexpr uint64_t kOpDebugHostMemSize = 2048; | static constexpr uint64_t kOpDebugHostMemSize = 2048; | ||||
| static constexpr uint64_t kOpDebugDevMemSize = sizeof(void *); | static constexpr uint64_t kOpDebugDevMemSize = sizeof(void *); | ||||
| @@ -106,19 +106,6 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::Kern | |||||
| return recv_node_ptr; | return recv_node_ptr; | ||||
| } | } | ||||
| bool KernelAdjust::ExitIndependent(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | |||||
| const auto &exe_orders = kernel_graph_ptr->execution_order(); | |||||
| for (const auto &node : exe_orders) { | |||||
| if (AnfAlgo::IsIndependentNode(node)) { | |||||
| MS_LOG(INFO) << "graph exit independent node"; | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | ||||
| device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); | device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); | ||||
| resource_manager.ResetResource(); | resource_manager.ResetResource(); | ||||
| @@ -133,10 +120,10 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| std::vector<AnfNodePtr> *mute_inputs = kernel_graph_ptr->MutableInputs(); | std::vector<AnfNodePtr> *mute_inputs = kernel_graph_ptr->MutableInputs(); | ||||
| MS_EXCEPTION_IF_NULL(mute_inputs); | MS_EXCEPTION_IF_NULL(mute_inputs); | ||||
| mute_inputs->push_back(switch_loop_input[kCurLoopCountParamName]); | |||||
| mute_inputs->push_back(switch_loop_input[kNextLoopCountParamName]); | |||||
| mute_inputs->push_back(switch_loop_input[kLoopCountParamName]); | |||||
| mute_inputs->push_back(switch_loop_input[kEpochParamName]); | mute_inputs->push_back(switch_loop_input[kEpochParamName]); | ||||
| mute_inputs->push_back(switch_loop_input[kIterLoopParamName]); | mute_inputs->push_back(switch_loop_input[kIterLoopParamName]); | ||||
| mute_inputs->push_back(switch_loop_input[kZeroParamName]); | |||||
| mute_inputs->push_back(switch_loop_input[kOneParamName]); | mute_inputs->push_back(switch_loop_input[kOneParamName]); | ||||
| for (const auto &input : kernel_graph_ptr->inputs()) { | for (const auto &input : kernel_graph_ptr->inputs()) { | ||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| @@ -161,7 +148,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| // getnext loop process | // getnext loop process | ||||
| // getnext loop stream switch op | // getnext loop stream switch op | ||||
| CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kGetNextStreamSwitch); | |||||
| CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | |||||
| MS_EXCEPTION_IF_NULL(getnext_switch_app); | MS_EXCEPTION_IF_NULL(getnext_switch_app); | ||||
| uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); | uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); | ||||
| AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); | AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); | ||||
| @@ -181,9 +168,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| } | } | ||||
| // update getnext loop stream switch true_branch_stream attr | // update getnext loop stream switch true_branch_stream attr | ||||
| AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), getnext_switch_app); | |||||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app); | AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app); | ||||
| AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kGetNextStreamSwitch), getnext_switch_app); | |||||
| // getnext loop fpbp start send | // getnext loop fpbp start send | ||||
| uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent(); | uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent(); | ||||
| @@ -200,7 +185,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| // End Of Sequence loop process | // End Of Sequence loop process | ||||
| // eos loop stream switch | // eos loop stream switch | ||||
| CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kEosStreamSwitch); | |||||
| CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | |||||
| MS_EXCEPTION_IF_NULL(eos_switch_app); | MS_EXCEPTION_IF_NULL(eos_switch_app); | ||||
| uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); | uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); | ||||
| AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get()); | AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get()); | ||||
| @@ -215,7 +200,6 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| // update eos loop stream switch true_branch_stream attr | // update eos loop stream switch true_branch_stream attr | ||||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(eos_stream_id), eos_switch_app); | AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(eos_stream_id), eos_switch_app); | ||||
| AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kEosStreamSwitch), eos_switch_app); | |||||
| // EndOfSequence op | // EndOfSequence op | ||||
| CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode); | CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode); | ||||
| @@ -233,27 +217,13 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| fpbp_active_streams.push_back(eos_switch_stream_id); | fpbp_active_streams.push_back(eos_switch_stream_id); | ||||
| } | } | ||||
| bool exit_independent = ExitIndependent(kernel_graph_ptr); | |||||
| if (exit_independent) { | |||||
| // Independet parallel | |||||
| CNodePtr independent_switch_app = | |||||
| CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kIndependentStreamSwitch); | |||||
| MS_EXCEPTION_IF_NULL(independent_switch_app); | |||||
| uint32_t independent_switch_stream_id = resource_manager.ApplyNewStream(); | |||||
| AnfAlgo::SetStreamId(independent_switch_stream_id, independent_switch_app.get()); | |||||
| AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), independent_switch_app); | |||||
| AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kIndependentStreamSwitch), independent_switch_app); | |||||
| exec_order.push_back(independent_switch_app); | |||||
| } | |||||
| // fpbp loop process | // fpbp loop process | ||||
| // fpbp loop stream switch | // fpbp loop stream switch | ||||
| CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kFpBpStreamSwitch); | |||||
| CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | |||||
| MS_EXCEPTION_IF_NULL(fpbp_switch_app); | MS_EXCEPTION_IF_NULL(fpbp_switch_app); | ||||
| uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); | uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); | ||||
| AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); | AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); | ||||
| AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app); | AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app); | ||||
| exec_order.push_back(fpbp_switch_app); | exec_order.push_back(fpbp_switch_app); | ||||
| // fpbp loop fpbp start recv | // fpbp loop fpbp start recv | ||||
| @@ -264,9 +234,9 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| // update fpbp loop stream switch true_branch_stream attr | // update fpbp loop stream switch true_branch_stream attr | ||||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(fpbp_stream_id), fpbp_switch_app); | AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(fpbp_stream_id), fpbp_switch_app); | ||||
| AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kFpBpStreamSwitch), fpbp_switch_app); | |||||
| // next loop AssignAdd | |||||
| CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, false); | |||||
| // fpbp loop AssignAdd | |||||
| CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); | |||||
| MS_EXCEPTION_IF_NULL(assign_add_one); | MS_EXCEPTION_IF_NULL(assign_add_one); | ||||
| AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get()); | AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get()); | ||||
| exec_order.push_back(assign_add_one); | exec_order.push_back(assign_add_one); | ||||
| @@ -304,11 +274,6 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| // fpbp loop other ops | // fpbp loop other ops | ||||
| (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order)); | (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order)); | ||||
| // current assign add op | |||||
| CNodePtr cur_assign_add = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, true); | |||||
| MS_EXCEPTION_IF_NULL(cur_assign_add); | |||||
| exec_order.push_back(cur_assign_add); | |||||
| // stream active to activate fpbp loop and eos loop | // stream active to activate fpbp loop and eos loop | ||||
| CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); | CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); | ||||
| MS_EXCEPTION_IF_NULL(fpbp_active_app); | MS_EXCEPTION_IF_NULL(fpbp_active_app); | ||||
| @@ -331,19 +296,13 @@ void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr<session::Kerne | |||||
| MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!"; | MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!"; | ||||
| } | } | ||||
| ParameterPtr cur_loop_count = std::make_shared<Parameter>(kernel_graph_ptr); | |||||
| MS_EXCEPTION_IF_NULL(cur_loop_count); | |||||
| cur_loop_count->set_name(kCurLoopCountParamName); | |||||
| cur_loop_count->set_abstract(paremeter_abstract_ptr); | |||||
| ParameterPtr loop_count_cur = kernel_graph_ptr->NewParameter(cur_loop_count); | |||||
| (*switch_loop_input)[kCurLoopCountParamName] = loop_count_cur; | |||||
| ParameterPtr loop_count = std::make_shared<Parameter>(kernel_graph_ptr); | |||||
| MS_EXCEPTION_IF_NULL(loop_count); | |||||
| loop_count->set_name(kLoopCountParamName); | |||||
| loop_count->set_abstract(paremeter_abstract_ptr); | |||||
| ParameterPtr loop_count_new = kernel_graph_ptr->NewParameter(loop_count); | |||||
| ParameterPtr next_loop_count = std::make_shared<Parameter>(kernel_graph_ptr); | |||||
| MS_EXCEPTION_IF_NULL(next_loop_count); | |||||
| next_loop_count->set_name(kNextLoopCountParamName); | |||||
| next_loop_count->set_abstract(paremeter_abstract_ptr); | |||||
| ParameterPtr loop_count_next = kernel_graph_ptr->NewParameter(next_loop_count); | |||||
| (*switch_loop_input)[kNextLoopCountParamName] = loop_count_next; | |||||
| (*switch_loop_input)[kLoopCountParamName] = loop_count_new; | |||||
| ParameterPtr iter_loop = std::make_shared<Parameter>(kernel_graph_ptr); | ParameterPtr iter_loop = std::make_shared<Parameter>(kernel_graph_ptr); | ||||
| iter_loop->set_name(kIterLoopParamName); | iter_loop->set_name(kIterLoopParamName); | ||||
| @@ -351,6 +310,12 @@ void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr<session::Kerne | |||||
| ParameterPtr iter_loop_new = kernel_graph_ptr->NewParameter(iter_loop); | ParameterPtr iter_loop_new = kernel_graph_ptr->NewParameter(iter_loop); | ||||
| (*switch_loop_input)[kIterLoopParamName] = iter_loop_new; | (*switch_loop_input)[kIterLoopParamName] = iter_loop_new; | ||||
| ParameterPtr zero = std::make_shared<Parameter>(kernel_graph_ptr); | |||||
| zero->set_name(kZeroParamName); | |||||
| zero->set_abstract(paremeter_abstract_ptr); | |||||
| ParameterPtr zero_new = kernel_graph_ptr->NewParameter(zero); | |||||
| (*switch_loop_input)[kZeroParamName] = zero_new; | |||||
| ParameterPtr one = std::make_shared<Parameter>(kernel_graph_ptr); | ParameterPtr one = std::make_shared<Parameter>(kernel_graph_ptr); | ||||
| one->set_name(kOneParamName); | one->set_name(kOneParamName); | ||||
| one->set_abstract(paremeter_abstract_ptr); | one->set_abstract(paremeter_abstract_ptr); | ||||
| @@ -378,22 +343,14 @@ kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBui | |||||
| } | } | ||||
| CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | ||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input, | |||||
| StreamSwitchKind kind) { | |||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) { | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | ||||
| {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | ||||
| auto typeNone_abstract = std::make_shared<abstract::AbstractNone>(); | auto typeNone_abstract = std::make_shared<abstract::AbstractNone>(); | ||||
| auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName); | auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName); | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| inputs.push_back(NewValueNode(stream_switch)); | inputs.push_back(NewValueNode(stream_switch)); | ||||
| if (kind == kFpBpStreamSwitch || kind == kEosStreamSwitch) { | |||||
| inputs.push_back(switch_loop_input.at(kCurLoopCountParamName)); | |||||
| } else if (kind == kGetNextStreamSwitch || kind == kIndependentStreamSwitch) { | |||||
| inputs.push_back(switch_loop_input.at(kNextLoopCountParamName)); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "unknown stream switch kind"; | |||||
| } | |||||
| inputs.push_back(switch_loop_input.at(kLoopCountParamName)); | |||||
| inputs.push_back(switch_loop_input.at(kIterLoopParamName)); | inputs.push_back(switch_loop_input.at(kIterLoopParamName)); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | ||||
| CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs); | CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs); | ||||
| @@ -476,9 +433,9 @@ CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr<session::Kern | |||||
| return end_of_sequence_node; | return end_of_sequence_node; | ||||
| } | } | ||||
| CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | |||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input, | |||||
| bool cur_loop) { | |||||
| CNodePtr KernelAdjust::CreateStreamAssignAddnOP( | |||||
| const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | |||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | ||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | ||||
| {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | ||||
| @@ -488,12 +445,7 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::K | |||||
| auto assign_add = std::make_shared<Primitive>(kAssignAddOpName); | auto assign_add = std::make_shared<Primitive>(kAssignAddOpName); | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| inputs.push_back(NewValueNode(assign_add)); | inputs.push_back(NewValueNode(assign_add)); | ||||
| if (cur_loop) { | |||||
| inputs.push_back(switch_loop_input.at(kCurLoopCountParamName)); | |||||
| } else { | |||||
| inputs.push_back(switch_loop_input.at(kNextLoopCountParamName)); | |||||
| } | |||||
| inputs.push_back(switch_loop_input.at(kLoopCountParamName)); | |||||
| inputs.push_back(switch_loop_input.at(kOneParamName)); | inputs.push_back(switch_loop_input.at(kOneParamName)); | ||||
| CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs); | CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs); | ||||
| MS_EXCEPTION_IF_NULL(assign_add_one); | MS_EXCEPTION_IF_NULL(assign_add_one); | ||||
| @@ -505,8 +457,8 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::K | |||||
| AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one); | AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one); | ||||
| AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one); | AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one); | ||||
| selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); | selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); | ||||
| MS_EXCEPTION_IF_NULL(switch_loop_input.at(kCurLoopCountParamName)); | |||||
| assign_add_one->set_abstract(switch_loop_input.at(kCurLoopCountParamName)->abstract()); | |||||
| MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName)); | |||||
| assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract()); | |||||
| return assign_add_one; | return assign_add_one; | ||||
| } | } | ||||
| @@ -561,23 +513,14 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph | |||||
| void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) { | void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) { | ||||
| MS_LOG(INFO) << "---------------- LoadSwitchInputs---"; | MS_LOG(INFO) << "---------------- LoadSwitchInputs---"; | ||||
| MS_EXCEPTION_IF_NULL(inputs); | MS_EXCEPTION_IF_NULL(inputs); | ||||
| // current loop count | |||||
| std::vector<int> shp = {1}; | std::vector<int> shp = {1}; | ||||
| tensor::TensorPtr cur_loop_count = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||||
| MS_EXCEPTION_IF_NULL(cur_loop_count); | |||||
| tensor::TensorPtr loop_count_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||||
| MS_EXCEPTION_IF_NULL(loop_count_tensor); | |||||
| int32_t *val = nullptr; | int32_t *val = nullptr; | ||||
| val = static_cast<int32_t *>(cur_loop_count->data_c()); | |||||
| val = static_cast<int32_t *>(loop_count_tensor->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(val); | MS_EXCEPTION_IF_NULL(val); | ||||
| *val = 0; | *val = 0; | ||||
| inputs->push_back(cur_loop_count); | |||||
| // next loop count | |||||
| tensor::TensorPtr next_loop_count = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||||
| MS_EXCEPTION_IF_NULL(next_loop_count); | |||||
| val = static_cast<int32_t *>(next_loop_count->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(val); | |||||
| *val = 0; | |||||
| inputs->push_back(next_loop_count); | |||||
| inputs->push_back(loop_count_tensor); | |||||
| // Epoch in device | // Epoch in device | ||||
| tensor::TensorPtr epoch_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | tensor::TensorPtr epoch_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | ||||
| @@ -587,7 +530,6 @@ void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) { | |||||
| *val = 0; | *val = 0; | ||||
| inputs->push_back(epoch_tensor); | inputs->push_back(epoch_tensor); | ||||
| // total loop count per iter | |||||
| tensor::TensorPtr iter_loop_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | tensor::TensorPtr iter_loop_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | ||||
| MS_EXCEPTION_IF_NULL(iter_loop_tensor); | MS_EXCEPTION_IF_NULL(iter_loop_tensor); | ||||
| val = static_cast<int32_t *>(iter_loop_tensor->data_c()); | val = static_cast<int32_t *>(iter_loop_tensor->data_c()); | ||||
| @@ -596,6 +538,13 @@ void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) { | |||||
| MS_LOG(INFO) << "iter_loop_tensor = " << *val; | MS_LOG(INFO) << "iter_loop_tensor = " << *val; | ||||
| inputs->push_back(iter_loop_tensor); | inputs->push_back(iter_loop_tensor); | ||||
| tensor::TensorPtr zero_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||||
| MS_EXCEPTION_IF_NULL(zero_tensor); | |||||
| val = static_cast<int32_t *>(zero_tensor->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(val); | |||||
| *val = 0; | |||||
| inputs->push_back(zero_tensor); | |||||
| tensor::TensorPtr one_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | tensor::TensorPtr one_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | ||||
| MS_EXCEPTION_IF_NULL(one_tensor); | MS_EXCEPTION_IF_NULL(one_tensor); | ||||
| val = static_cast<int32_t *>(one_tensor->data_c()); | val = static_cast<int32_t *>(one_tensor->data_c()); | ||||
| @@ -33,19 +33,13 @@ | |||||
| using mindspore::device::ascend::ProfilingTraceInfo; | using mindspore::device::ascend::ProfilingTraceInfo; | ||||
| using mindspore::device::ascend::ProfilingUtils; | using mindspore::device::ascend::ProfilingUtils; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| constexpr auto kCurLoopCountParamName = "cur_loop_count"; | |||||
| constexpr auto kNextLoopCountParamName = "next_loop_count"; | |||||
| constexpr auto kLoopCountParamName = "loop_count"; | |||||
| constexpr auto kIterLoopParamName = "iter_loop"; | constexpr auto kIterLoopParamName = "iter_loop"; | ||||
| constexpr auto kZeroParamName = "zero"; | |||||
| constexpr auto kOneParamName = "one"; | constexpr auto kOneParamName = "one"; | ||||
| constexpr auto kEpochParamName = "loop_epoch"; | constexpr auto kEpochParamName = "loop_epoch"; | ||||
| constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; | constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; | ||||
| constexpr uint32_t kSecondStreamSwitchLabel = 2; | constexpr uint32_t kSecondStreamSwitchLabel = 2; | ||||
| enum StreamSwitchKind { | |||||
| kFpBpStreamSwitch = 0, | |||||
| kGetNextStreamSwitch = 1, | |||||
| kEosStreamSwitch = 2, | |||||
| kIndependentStreamSwitch = 3 | |||||
| }; | |||||
| namespace device { | namespace device { | ||||
| class KernelAdjust { | class KernelAdjust { | ||||
| @@ -71,22 +65,18 @@ class KernelAdjust { | |||||
| void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | ||||
| std::map<std::string, mindspore::ParameterPtr> *switch_loop_input); | std::map<std::string, mindspore::ParameterPtr> *switch_loop_input); | ||||
| CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | ||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input, | |||||
| StreamSwitchKind kind); | |||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input); | |||||
| CNodePtr CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, const CNodePtr &node, | CNodePtr CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, const CNodePtr &node, | ||||
| size_t output_idx); | size_t output_idx); | ||||
| CNodePtr CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | CNodePtr CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | ||||
| const CNodePtr &getnext_cnode); | const CNodePtr &getnext_cnode); | ||||
| CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | ||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input, | |||||
| bool cur_loop); | |||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input); | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats, | kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats, | ||||
| const std::vector<TypeId> &type_ids); | const std::vector<TypeId> &type_ids); | ||||
| void LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs); | void LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs); | ||||
| void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info, | void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info, | ||||
| NotNull<session::KernelGraph *> kernel_graph_ptr); | NotNull<session::KernelGraph *> kernel_graph_ptr); | ||||
| bool ExitIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| }; | }; | ||||
| } // namespace device | } // namespace device | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -580,14 +580,6 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in | |||||
| MS_LOG(INFO) << "GetNext disable mem_reuse"; | MS_LOG(INFO) << "GetNext disable mem_reuse"; | ||||
| type = kDynamicMem; | type = kDynamicMem; | ||||
| } | } | ||||
| if (node->isa<CNode>()) { | |||||
| bool independent = AnfAlgo::IsIndependentNode(node->cast<CNodePtr>()); | |||||
| if (independent && type == kReuseDynamicMem) { | |||||
| MS_LOG(INFO) << "Independent disable mem_reuse"; | |||||
| type = kDynamicMem; | |||||
| } | |||||
| } | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(node); | auto kernel_mod = AnfAlgo::GetKernelMod(node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | MS_EXCEPTION_IF_NULL(kernel_mod); | ||||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | auto output_sizes = kernel_mod->GetOutputSizeList(); | ||||
| @@ -210,7 +210,6 @@ constexpr auto kAttrDataType = "data_type"; | |||||
| constexpr auto kAttrActiveTarget = "active_target"; | constexpr auto kAttrActiveTarget = "active_target"; | ||||
| constexpr auto kAttrActiveStreamList = "active_stream_list"; | constexpr auto kAttrActiveStreamList = "active_stream_list"; | ||||
| constexpr auto kAttrTrueBranchStream = "true_branch_stream"; | constexpr auto kAttrTrueBranchStream = "true_branch_stream"; | ||||
| constexpr auto kAttrStreamSwitchKind = "stream_switch_kind"; | |||||
| constexpr auto kAttrEventId = "event_id"; | constexpr auto kAttrEventId = "event_id"; | ||||
| constexpr auto kAttrDynInput = "dynamic"; | constexpr auto kAttrDynInput = "dynamic"; | ||||
| constexpr auto kAttrDynInputSizes = "dyn_input_sizes"; | constexpr auto kAttrDynInputSizes = "dyn_input_sizes"; | ||||