From 90c6c6b420406eb708bdf109a60828607f35d42d Mon Sep 17 00:00:00 2001
From: zhengyuanhua
Date: Mon, 4 Jan 2021 16:26:42 +0800
Subject: [PATCH 1/2] profiling training trace

---
 ge/common/profiling/profiling_manager.cc      |   2 +
 ge/common/types.cc                            |   3 +
 ge/graph/build/graph_builder.cc               |  52 +++++
 ge/graph/build/graph_builder.h                |   1 +
 ge/graph/build/task_generator.cc              | 113 +++++++---
 ge/graph/build/task_generator.h               |   7 +-
 .../load/new_model_manager/davinci_model.cc   |  14 +-
 .../load/new_model_manager/davinci_model.h    |   2 +
 ge/hybrid/executor/worker/execution_engine.cc |   2 +
 ge/hybrid/model/hybrid_model_builder.cc       | 196 +++++++++++++++++-
 ge/hybrid/model/hybrid_model_builder.h        |   6 +
 .../node_executor/rts/rts_node_executor.cc    |  33 +++
 .../node_executor/rts/rts_node_executor.h     |  13 ++
 ge/hybrid/node_executor/task_context.h        |   2 +-
 inc/framework/common/ge_types.h               |   2 +
 inc/framework/common/types.h                  |   3 +
 metadef                                       |   2 +-
 parser                                        |   2 +-
 18 files changed, 422 insertions(+), 33 deletions(-)

diff --git a/ge/common/profiling/profiling_manager.cc b/ge/common/profiling/profiling_manager.cc
index 92417286..aad2bbe3 100644
--- a/ge/common/profiling/profiling_manager.cc
+++ b/ge/common/profiling/profiling_manager.cc
@@ -302,6 +302,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin
     }

     data.append(" model_id:").append(std::to_string(model_id));
+    data.append(" task_id:").append(std::to_string(graph.task_id));
+    data.append(" stream_id:").append(std::to_string(graph.stream_id));
     data.append("\n");

     GraphDescReport(device_id, data);
diff --git a/ge/common/types.cc b/ge/common/types.cc
index 1cc70347..268e7caa 100644
--- a/ge/common/types.cc
+++ b/ge/common/types.cc
@@ -480,6 +480,9 @@ REGISTER_OPTYPE_DEFINE(HVDWAIT, "HorovodWait");
 // aicpu op for online_infer dynamic_dims
 REGISTER_OPTYPE_DEFINE(GETDYNAMICDIMS, "GetDynamicDims");

+// profiling training trace node
+REGISTER_OPTYPE_DEFINE(PROFILINGTRAININGTRACE, "ProfilingTrainingTrace");
+
 const std::string MODEL_ATTR_TASKS = "tasks";
 const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR = "task_gen_base_addr";
 const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr";
diff --git a/ge/graph/build/graph_builder.cc b/ge/graph/build/graph_builder.cc
index dce40c3e..143d5550 100644
--- a/ge/graph/build/graph_builder.cc
+++ b/ge/graph/build/graph_builder.cc
@@ -421,6 +421,52 @@ static Status GenerateTaskForConstant(const std::shared_ptr<ComputeGraph> &graph
   return SUCCESS;
 }

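+// Locate the FP/BP/End profiling points on the (temporarily static-flagged) graph and record them
+// as boolean attributes on the matching nodes; HybridModelBuilder reads these attributes back when
+// the dynamic-shape graph is loaded to create the corresponding ProfilingTrainingTrace nodes.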
graph is %s, idx %u", op_desc->GetName().c_str(), node_index); + (void)ge::AttrUtils::SetBool(op_desc, ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); + } + if (profiling_point.bp_index == node_index) { + GELOGI("The bp node of dynamic graph is %s, idx %u", op_desc->GetName().c_str(), node_index); + (void)ge::AttrUtils::SetBool(op_desc, ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); + } + for (size_t i = 0; i < all_reduce_node_index.size(); i++) { + if (all_reduce_node_index[i] == node_index) { + GELOGI("The all reduce node of dynamic graph is %s, idx %u", op_desc->GetName().c_str(), node_index); + (void)ge::AttrUtils::SetBool(op_desc, ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); + continue; + } + } + if (profiling_point.end_index.find(node_index) != profiling_point.end_index.end()) { + GELOGI("The end node of dynamic graph is %s, idx %u", op_desc->GetName().c_str(), node_index); + (void)ge::AttrUtils::SetBool(op_desc, ATTR_NAME_INSERT_END_PROFILILNG_TASK, true); + } + } + return SUCCESS; +} + Status GraphBuilder::BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, @@ -437,6 +483,12 @@ Status GraphBuilder::BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, } } + // Set fp bp profiling task attr for graph + if (MarkFpBpProfilingTaskAttr(comp_graph) != SUCCESS) { + GELOGE(FAILED, "Set fp bp profiling task attr for graph."); + return FAILED; + } + auto all_graphs = comp_graph->GetAllSubgraphs(); if (all_graphs.empty()) { all_graphs.push_back(comp_graph); diff --git a/ge/graph/build/graph_builder.h b/ge/graph/build/graph_builder.h index b828a80d..524b60e0 100644 --- a/ge/graph/build/graph_builder.h +++ b/ge/graph/build/graph_builder.h @@ -60,6 +60,7 @@ class GraphBuilder { Status UpdateParentNodeOutputSize(const ge::ComputeGraphPtr &graph, ge::NodePtr &parent_node_ptr); Status CalcDynShapeRootGraphDataSize(const ge::OpDescPtr &op_desc); Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector &subgraph_ptr_list); + Status MarkFpBpProfilingTaskAttr(ComputeGraphPtr &com_graph); Status BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index 7e45ad61..21e82d11 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -274,6 +274,7 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra }; GE_MAKE_GUARD(release, callback); + uint64_t all_reduce_node_idx = 0; for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -292,7 +293,7 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra // Part2: Call auto fusion_task_info = FusionTaskInfo{run_context, graph, node, op_desc, node_index, ge_lib, - ops_kernel_manager, task_def_list, op_name_map, profiling_point, all_reduce_nodes}; + ops_kernel_manager, task_def_list, op_name_map, profiling_point, all_reduce_nodes, all_reduce_node_idx}; GE_CHK_STATUS_RET(GenerateTaskForFusionNode(fusion_task_info, fusion_nodes, fusion_nodes_seen), "Call GenerateTaskForFusionNode node:%s(%s) failed", name.c_str(), type.c_str()); // continue directly @@ -316,7 +317,8 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra type.c_str()); // Profiling task size_t 
     size_t task_list_size_before = task_def_list.size();
-    GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list));
+    GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes,
+                                                node_index, task_def_list, all_reduce_node_idx));
     int64_t op_id = op_desc->GetId();
     // Compatible with dynamic shape scenes, the default is 0
     int64_t stream_id = 0;
@@ -336,8 +338,8 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
       return ret;
     }
     // Profiling task
-    GE_CHK_STATUS_RET(InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list));
-
+    GE_CHK_STATUS_RET(InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes,
+                                               node_index, task_def_list, all_reduce_node_idx));
     size_t task_list_size_after = task_def_list.size();
     // If tasks is reduced
     if (task_list_size_after < task_list_size_before) {
@@ -380,6 +382,7 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info
   auto &op_name_map = fusion_task_info.op_name_map;
   auto &profiling_point = fusion_task_info.profiling_point;
   auto &all_reduce_nodes = fusion_task_info.all_reduce_nodes;
+  auto &all_reduce_idx = fusion_task_info.all_reduce_node_idx;
   // If op_desc have this attr, call nodes with same group key in a stream together
   if (ge::AttrUtils::GetInt(fusion_op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key) &&
       (fusion_nodes_seen.count(node.get()) == 0)) {
@@ -426,7 +429,8 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info
         return INTERNAL_ERROR;
       }
       // profiling task
-      (void)InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list);
+      (void)InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes,
+                                      node_index, task_def_list, all_reduce_idx);
       run_context.stream = run_context.graphStreamList[stream_id];
       GELOGI("Fusion: Call %s to generate fusion_node:[fusion_node_name:%s(%s), id:%ld, stream_id:%ld] task.",
              op_kernel_lib_name.c_str(), fusion_node_name.c_str(), fusion_node_type.c_str(), op_id, stream_id);
@@ -439,7 +443,8 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info
         return ret;
       }
       // profiling task
-      (void)InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list);
+      (void)InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes,
+                                     node_index, task_def_list, all_reduce_idx);
       size_t task_list_size_after = task_def_list.size();
       // if tasks is reduced
       if (task_list_size_after < task_list_size_before) {
@@ -830,6 +835,11 @@ Status TaskGenerator::GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint
   return SUCCESS;
 }

+Status TaskGenerator::FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
+                                             std::vector<uint32_t> &all_reduce_nodes) {
+  return FindProfilingTaskIndex(graph, profiling_point, all_reduce_nodes);
+}
+
 Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
                                              vector<uint32_t> &all_reduce_nodes) const {
   GE_CHECK_NOTNULL(graph);
@@ -840,7 +850,6 @@ Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, Profi
     GELOGD("Profiling is not open.");
     return SUCCESS;
   }
-
   GELOGI("Start get FP/BP index.");
   std::string fp_point_str;
   std::string bp_point_str;
@@ -878,18 +887,27 @@ Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, Profi
   return SUCCESS;
 }
-
 Status TaskGenerator::InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
&op_desc, const ProfilingPoint &profiling_point, vector &all_reduce_nodes, uint32_t node_index, - vector &task_def_list) { + vector &task_def_list, uint64_t &all_reduce_node_idx) { const char *profiling_mode = std::getenv(kProfilingMode); bool is_profiling = (profiling_mode != nullptr) || ProfilingManager::Instance().ProfilingOn() || ProfilingManager::Instance().ProfilingTrainingTraceOn(); - if (!is_profiling || (profiling_point.fp_index == 0) || (profiling_point.bp_index == 0) || - (profiling_point.end_index.empty())) { + bool is_insert_fp_profiling_task = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_INSERT_FP_PROFILILNG_TASK, is_insert_fp_profiling_task); + bool is_insert_bp_profiling_task = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_INSERT_BP_PROFILILNG_TASK, is_insert_bp_profiling_task); + bool no_insert_profiling_task = ((profiling_point.fp_index == 0) || (profiling_point.bp_index == 0) || + (profiling_point.end_index.empty())) && + (!(is_insert_fp_profiling_task || is_insert_bp_profiling_task)); + if (!is_profiling || no_insert_profiling_task) { return SUCCESS; } - if (profiling_point.fp_index == node_index) { + GELOGD("Insert fp profiling task: %d, insert bp profiling task: %d, fp index: %u, bp index: %u, end index size: %zu", + is_insert_fp_profiling_task, is_insert_bp_profiling_task, profiling_point.fp_index, profiling_point.bp_index, + profiling_point.end_index.size()); + + if ((profiling_point.fp_index == node_index) || is_insert_fp_profiling_task) { uint64_t jobid_log_id = ge::GetContext().TraceId(); GELOGI("The first FP operator is %s, idx %u, job_id %lu", op_desc->GetName().c_str(), node_index, jobid_log_id); @@ -913,22 +931,40 @@ Status TaskGenerator::InsertProfilingTaskBefore(const OpDescPtr &op_desc, const task_def_list.emplace_back(fp_task_def); } - for (size_t i = 0; i < all_reduce_nodes.size(); i++) { - if (all_reduce_nodes[i] != node_index) { - continue; + bool is_all_reduce = (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HVDCALLBACKALLREDUCE); + uint64_t all_reduce_task_idx = 0; + bool is_insert_all_reduce_task = false; + if (is_all_reduce && is_insert_bp_profiling_task) { + all_reduce_task_idx = all_reduce_node_idx; + is_insert_all_reduce_task = true; + } + if (is_all_reduce) { + all_reduce_node_idx++; + } + if (!is_insert_all_reduce_task) { + for (size_t i = 0; i < all_reduce_nodes.size(); i++) { + if (all_reduce_nodes[i] == node_index) { + all_reduce_task_idx = i; + is_insert_all_reduce_task = true; + break; + } } + } + + if (is_insert_all_reduce_task) { GELOGI("The start allreduce operator is %s, idx %u", op_desc->GetName().c_str(), node_index); TaskDef ar_task_def; ar_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE); ar_task_def.set_stream_id(op_desc->GetStreamId()); LogTimeStampDef *ar_log_def = ar_task_def.mutable_log_timestamp(); if (ar_log_def != nullptr) { - GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(i, kProfilingArStep), + GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(all_reduce_task_idx, kProfilingArStep), GELOGE(FAILED, "Multiply result is out of range."); return FAILED); - auto log_id = i * kProfilingArStep + kProfilingArStartLogid; + auto log_id = all_reduce_task_idx * kProfilingArStep + kProfilingArStartLogid; ar_log_def->set_logid(log_id); ar_log_def->set_notify(false); + (void)ge::AttrUtils::SetInt(op_desc, ATTR_NAME_INSERT_PROFILILNG_TASK_LOG_ID, log_id); } task_def_list.push_back(ar_task_def); } @@ -937,16 +973,27 @@ Status TaskGenerator::InsertProfilingTaskBefore(const OpDescPtr &op_desc, const 
+      (void)ge::AttrUtils::SetInt(op_desc, ATTR_NAME_INSERT_PROFILILNG_TASK_LOG_ID, log_id);
     }
     task_def_list.push_back(ar_task_def);
   }
@@ -937,16 +973,27 @@ Status TaskGenerator::InsertProfilingTaskBefore(const OpDescPtr &op_desc, const
 Status TaskGenerator::InsertProfilingTaskAfter(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
                                                vector<uint32_t> &all_reduce_nodes, uint32_t node_index,
-                                               vector<TaskDef> &task_def_list) {
+                                               vector<TaskDef> &task_def_list, uint64_t all_reduce_node_idx) {
   GE_CHECK_NOTNULL(op_desc);
   const char *profiling_mode = std::getenv(kProfilingMode);
   bool is_profiling = (profiling_mode != nullptr) || ProfilingManager::Instance().ProfilingOn() ||
                       ProfilingManager::Instance().ProfilingTrainingTraceOn();
-  if (!is_profiling || (profiling_point.fp_index == 0) || (profiling_point.bp_index == 0) ||
-      (profiling_point.end_index.empty())) {
+  bool is_insert_bp_profiling_task = false;
+  (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_INSERT_BP_PROFILILNG_TASK, is_insert_bp_profiling_task);
+  bool is_insert_end_profiling_task = false;
+  (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_INSERT_END_PROFILILNG_TASK, is_insert_end_profiling_task);
+  bool no_insert_profiling_task = ((profiling_point.fp_index == 0) || (profiling_point.bp_index == 0) ||
+                                   (profiling_point.end_index.empty())) &&
+                                  (!(is_insert_bp_profiling_task || is_insert_end_profiling_task));
+  if (!is_profiling || no_insert_profiling_task) {
     return SUCCESS;
   }
-  if (profiling_point.bp_index == node_index) {
+  GELOGD("Insert bp profiling task: %d, insert end profiling task: %d, fp index: %u, bp index: %u, end index size: %zu",
+         is_insert_bp_profiling_task, is_insert_end_profiling_task, profiling_point.fp_index, profiling_point.bp_index,
+         profiling_point.end_index.size());
+
+  bool is_all_reduce = (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HVDCALLBACKALLREDUCE);
+  if ((profiling_point.bp_index == node_index) || (!is_all_reduce && is_insert_bp_profiling_task)) {
     GELOGI("The last BP operator is %s, idx %u", op_desc->GetName().c_str(), node_index);
     TaskDef bp_task_def;
     bp_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE);
@@ -957,7 +1004,9 @@ Status TaskGenerator::InsertProfilingTaskAfter(const OpDescPtr &op_desc, const P
     bp_log_def->set_notify(false);
     task_def_list.emplace_back(bp_task_def);
   }
-  if (profiling_point.end_index.find(node_index) != profiling_point.end_index.end()) {
+
+  if (profiling_point.end_index.find(node_index) != profiling_point.end_index.end() ||
+      is_insert_end_profiling_task) {
     GELOGI("The iteration end operator is %s, idx %u", op_desc->GetName().c_str(), node_index);
     TaskDef end_task_def;
     end_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE);
@@ -969,20 +1018,32 @@ Status TaskGenerator::InsertProfilingTaskAfter(const OpDescPtr &op_desc, const P
     task_def_list.emplace_back(end_task_def);
   }

+  uint32_t all_reduce_task_idx = 0;
+  bool is_insert_all_reduce_task = false;
+  if (is_all_reduce && is_insert_bp_profiling_task) {
+    all_reduce_task_idx = all_reduce_node_idx;
+    is_insert_all_reduce_task = true;
+  }
+
   for (size_t i = 0; i < all_reduce_nodes.size(); i++) {
-    if (all_reduce_nodes[i] != node_index) {
-      continue;
+    if (all_reduce_nodes[i] == node_index) {
+      all_reduce_task_idx = i;
+      is_insert_all_reduce_task = true;
+      break;
     }
+  }
+
+  if (is_insert_all_reduce_task) {
     GELOGI("The end allreduce operator is %s, idx %u", op_desc->GetName().c_str(), node_index);
     TaskDef ar_task_def;
     ar_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE);
     ar_task_def.set_stream_id(op_desc->GetStreamId());
     LogTimeStampDef *ar_log_def = ar_task_def.mutable_log_timestamp();
     GE_CHECK_NOTNULL(ar_log_def);
-    GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(i, kProfilingArStep),
+    GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(all_reduce_task_idx, kProfilingArStep),
                     GELOGE(FAILED, "Multiply result is out of range.");
                     return FAILED);
-    auto log_id = i * kProfilingArStep + kProfilingArEndLogid;
+    auto log_id = all_reduce_task_idx * kProfilingArStep + kProfilingArEndLogid;
     ar_log_def->set_logid(log_id);
     ar_log_def->set_notify(false);
     task_def_list.emplace_back(ar_task_def);
diff --git a/ge/graph/build/task_generator.h b/ge/graph/build/task_generator.h
index c93b2007..5970954c 100755
--- a/ge/graph/build/task_generator.h
+++ b/ge/graph/build/task_generator.h
@@ -51,6 +51,7 @@ struct FusionTaskInfo {
   std::map<int64_t, std::string> &op_name_map;
   ProfilingPoint &profiling_point;
   vector<uint32_t> all_reduce_nodes;
+  uint64_t all_reduce_node_idx;
 };

 class TaskGenerator {
@@ -76,6 +77,8 @@ class TaskGenerator {
   ///
   Status GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t session_id, RunContext &run_context);
+  Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
+                                std::vector<uint32_t> &all_reduce_nodes);

  private:
   Status UpdateAnchorStatus(const NodePtr &node);
@@ -126,10 +129,10 @@ class TaskGenerator {
                                 std::vector<uint32_t> &all_reduce_nodes) const;
   Status InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
                                    std::vector<uint32_t> &all_reduce_nodes, uint32_t node_index,
-                                   std::vector<domi::TaskDef> &task_def_list);
+                                   std::vector<domi::TaskDef> &task_def_list, uint64_t &all_reduce_node_idx);
   Status InsertProfilingTaskAfter(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
                                   std::vector<uint32_t> &all_reduce_nodes, uint32_t node_index,
-                                  std::vector<domi::TaskDef> &task_def_list);
+                                  std::vector<domi::TaskDef> &task_def_list, uint64_t all_reduce_node_idx);

   static bool IsProfPoint(const OpDescPtr &op, const std::string &name);
diff --git a/ge/graph/load/new_model_manager/davinci_model.cc b/ge/graph/load/new_model_manager/davinci_model.cc
index f3d6f82b..0e97f9c0 100755
--- a/ge/graph/load/new_model_manager/davinci_model.cc
+++ b/ge/graph/load/new_model_manager/davinci_model.cc
@@ -3082,6 +3082,8 @@ Status DavinciModel::DistributeTask() {
       task_desc_info.stream_id = task->GetStreamId();
       task_desc_info.shape_type = "static";
       task_desc_info.cur_iter_num = 0;
+      profiler_report_op_info_[task_desc_info.op_name] =
+          std::pair<uint32_t, uint32_t>(task_desc_info.task_id, task_desc_info.stream_id);
       task_desc_info_.emplace_back(task_desc_info);
       if (flag) {
         if (task->GetSktTaskID() != 0xFFFFFFFF) {
@@ -3089,6 +3091,8 @@ Status DavinciModel::DistributeTask() {
           string op_name = "super_kernel_" + to_string(task_index);
           task_desc_info.op_name = op_name;
           task_desc_info.task_id = task->GetSktTaskID();
+          profiler_report_op_info_[task_desc_info.op_name] =
+              std::pair<uint32_t, uint32_t>(task_desc_info.task_id, task_desc_info.stream_id);
           task_desc_info_.emplace_back(task_desc_info);
         }
       }
@@ -3960,7 +3964,15 @@ Status DavinciModel::GetComputeGraphInfo(vector<ComputeGraphDescInfo> &graph_des
     compute_graph_info.output_format = op_desc.output_format;
     compute_graph_info.output_shape = op_desc.output_shape;
     compute_graph_info.output_data_type = op_desc.output_data_type;
-
+    uint32_t task_id = 0;
+    uint32_t stream_id = 0;
+    auto iter = profiler_report_op_info_.find(op_desc.op_name);
+    if (iter != profiler_report_op_info_.end()) {
+      task_id = iter->second.first;
+      stream_id = iter->second.second;
+    }
+    compute_graph_info.task_id = task_id;
+    compute_graph_info.stream_id = stream_id;
     graph_desc_info.emplace_back(compute_graph_info);
   }
   return SUCCESS;
diff --git a/ge/graph/load/new_model_manager/davinci_model.h b/ge/graph/load/new_model_manager/davinci_model.h
index 6b930b05..cb3902df 100755
--- a/ge/graph/load/new_model_manager/davinci_model.h
+++ b/ge/graph/load/new_model_manager/davinci_model.h
@@ -976,6 +976,8 @@ class DavinciModel {
   // for profiling task and graph info
   vector<TaskDescInfo> task_desc_info_;
+  std::map<std::string, std::pair<uint32_t, uint32_t>> profiler_report_op_info_;
+
   int64_t maxDumpOpNum_;
   // for data dump
   DataDumper data_dumper_;
diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc
index 21dd8e4b..e9c6ef29 100755
--- a/ge/hybrid/executor/worker/execution_engine.cc
+++ b/ge/hybrid/executor/worker/execution_engine.cc
@@ -221,6 +221,8 @@ Status NodeDoneCallback::GetGraphDescInfo(const NodePtr node, const HybridModel
       tmp_compute_graph_info.output_shape.emplace_back(output_desc.GetShape().GetDims());
       tmp_compute_graph_info.output_data_type.emplace_back(output_desc.GetDataType());
     }
+    tmp_compute_graph_info.task_id = context_->GetTaskId();
+    tmp_compute_graph_info.stream_id = context_->GetStreamId();
     compute_graph_info.emplace_back(tmp_compute_graph_info);
     GELOGD("GetComputeGraphInfo of node [%s] end.", node->GetName().c_str());
   }
diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc
index 46c9c39b..32fc495a 100755
--- a/ge/hybrid/model/hybrid_model_builder.cc
+++ b/ge/hybrid/model/hybrid_model_builder.cc
@@ -35,11 +35,22 @@

 namespace ge {
 namespace hybrid {
+using domi::LogTimeStampDef;
+using domi::TaskDef;
 namespace {
 const uint32_t kSubgraphIndex = 0U;
 const uint32_t kVarOutputIndex = 0U;
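+// Trace log ids used by the hybrid (dynamic-shape) path; they mirror the ids the static-graph
+// TaskGenerator emits: FP start = 1, BP end = 2, iteration end = 65535.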
+const uint64_t kProfilingFpStartLogid = 1U;
+const uint64_t kProfilingBpEndLogid = 2U;
+const uint64_t kProfilingIterEndLogid = 65535U;
 const int kBytes = 8;
 const char *const kOwnerGraphIsUnknown = "OwnerGraphIsUnknown";
+const char *const kProfilingGraph = "ProfilingGraph";
+const char *const kProfilingFpNode = "ProfilingFpNode";
+const char *const kProfilingBpNode = "ProfilingBpNode";
+const char *const kProfilingEndNode = "ProfilingEndNode";
+const char *const kProfilingArNode = "ProfilingAllReduceNode";
+const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE";

 Status SetOutputNameAttr(ComputeGraph &graph) {
   vector<string> output_names;
@@ -1531,6 +1542,188 @@ Status HybridModelBuilder::RecoverGraphUnknownFlag() {
   return SUCCESS;
 }

+Status HybridModelBuilder::GenerateFpProfilingTask(const OpDescPtr &op_desc, vector<TaskDef> &task_def_list) {
+  uint64_t jobid_log_id = ge::GetContext().TraceId();
+  GELOGD("The first FP operator is %s, job_id %lu", op_desc->GetName().c_str(), jobid_log_id);
+
+  TaskDef job_task_def;
+  job_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE);
+  job_task_def.set_stream_id(op_desc->GetStreamId());
+  LogTimeStampDef *job_log_def = job_task_def.mutable_log_timestamp();
+  if (job_log_def != nullptr) {
+    job_log_def->set_logid(jobid_log_id);
+    job_log_def->set_notify(false);
+  }
+  task_def_list.emplace_back(job_task_def);
+  TaskDef fp_task_def;
+  fp_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE);
+  fp_task_def.set_stream_id(op_desc->GetStreamId());
+  LogTimeStampDef *fp_log_def = fp_task_def.mutable_log_timestamp();
+  if (fp_log_def != nullptr) {
+    fp_log_def->set_logid(kProfilingFpStartLogid);
+    fp_log_def->set_notify(false);
+  }
+  task_def_list.emplace_back(fp_task_def);
+
+  return SUCCESS;
+}
+
+Status HybridModelBuilder::GenerateArProfilingTask(const OpDescPtr &op_desc, int64_t log_id,
+                                                   vector<TaskDef> &task_def_list) {
+  TaskDef ar_task_def;
+  ar_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE);
+  ar_task_def.set_stream_id(op_desc->GetStreamId());
+  LogTimeStampDef *ar_log_def = ar_task_def.mutable_log_timestamp();
+  if (ar_log_def != nullptr) {
+    ar_log_def->set_logid(log_id);
+    ar_log_def->set_notify(false);
+  }
+  task_def_list.emplace_back(ar_task_def);
+
+  return SUCCESS;
+}
+
+Status HybridModelBuilder::GenerateBpProfilingTask(const OpDescPtr &op_desc, vector<TaskDef> &task_def_list) {
+  TaskDef bp_task_def;
+  bp_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE);
+  bp_task_def.set_stream_id(op_desc->GetStreamId());
+  LogTimeStampDef *bp_log_def = bp_task_def.mutable_log_timestamp();
+  GE_CHECK_NOTNULL(bp_log_def);
+  bp_log_def->set_logid(kProfilingBpEndLogid);
+  bp_log_def->set_notify(false);
+  task_def_list.emplace_back(bp_task_def);
+
+  return SUCCESS;
+}
+
+Status HybridModelBuilder::GenerateEndProfilingTask(const OpDescPtr &op_desc, vector<TaskDef> &task_def_list) {
+  TaskDef end_task_def;
+  end_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE);
+  end_task_def.set_stream_id(op_desc->GetStreamId());
+  LogTimeStampDef *end_log_def = end_task_def.mutable_log_timestamp();
+  GE_CHECK_NOTNULL(end_log_def);
+  end_log_def->set_logid(kProfilingIterEndLogid);
+  end_log_def->set_notify(true);
+  task_def_list.emplace_back(end_task_def);
+
+  return SUCCESS;
+}
+
+Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, const NodePtr &node) {
+  GE_CHECK_NOTNULL(node);
+  const OpDescPtr &op_desc = node->GetOpDesc();
+  GE_CHECK_NOTNULL(op_desc);
+  const auto &compute_graph = MakeShared<ComputeGraph>(kProfilingGraph);
+  GE_CHECK_NOTNULL(compute_graph);
+
+  NodePtr node_ptr = nullptr;
+  vector<TaskDef> task_def_list;
+  // create fp node
+  bool is_insert_fp_profiling_task = false;
+  (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_INSERT_FP_PROFILILNG_TASK, is_insert_fp_profiling_task);
+  if (is_insert_fp_profiling_task) {
+    (void)GenerateFpProfilingTask(op_desc, task_def_list);
+    auto fp_desc = MakeShared<OpDesc>(kProfilingFpNode, PROFILINGTRAININGTRACE);
+    GE_CHECK_NOTNULL(fp_desc);
+    fp_desc->SetOpKernelLibName(kEngineNameRts);
+    node_ptr = compute_graph->AddNode(fp_desc);
+    GELOGD("Create fp profiling node success before.");
+  }
+  // create all reduce start node
+  bool is_insert_bp_profiling_task = false;
+  (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_INSERT_BP_PROFILILNG_TASK, is_insert_bp_profiling_task);
+  bool is_all_reduce = (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HVDCALLBACKALLREDUCE);
+  if (is_all_reduce && is_insert_bp_profiling_task) {
+    int64_t log_id = 0;
+    (void)ge::AttrUtils::GetInt(op_desc, ATTR_NAME_INSERT_PROFILILNG_TASK_LOG_ID, log_id);
+    GELOGD("All reduce node profiling task log id: %ld before", log_id);
+    (void) GenerateArProfilingTask(op_desc, log_id, task_def_list);
+    string op_name = string(kProfilingArNode) + std::to_string(log_id);
+    auto ar_desc_start = MakeShared<OpDesc>(op_name, PROFILINGTRAININGTRACE);
+    GE_CHECK_NOTNULL(ar_desc_start);
+    ar_desc_start->SetOpKernelLibName(kEngineNameRts);
+    node_ptr = compute_graph->AddNode(ar_desc_start);
+    GELOGD("Create all reduce start profiling node success before.");
+  }
+
+  if (node_ptr != nullptr) {
+    for (const auto &task_def : task_def_list) {
+      hybrid_model_.task_defs_[node_ptr].emplace_back(task_def);
+    }
+    NodeItem *node_item = nullptr;
+    GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node_ptr, &node_item));
+    node_item->input_start = 0;
+    node_item->output_start = 0;
+    graph_item.node_items_.emplace_back(node_item);
+  } else {
+    GELOGD("No need to create profiling node before.");
+  }
+
+  return SUCCESS;
+}
+
+Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const NodePtr &node) {
+  GE_CHECK_NOTNULL(node);
+  const OpDescPtr &op_desc = node->GetOpDesc();
+  GE_CHECK_NOTNULL(op_desc);
+  const auto &compute_graph = MakeShared<ComputeGraph>(kProfilingGraph);
+  GE_CHECK_NOTNULL(compute_graph);
+
+  NodePtr node_ptr = nullptr;
+  vector<TaskDef> task_def_list;
+  // Create all reduce end node
+  bool is_insert_bp_profiling_task = false;
+  (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_INSERT_BP_PROFILILNG_TASK, is_insert_bp_profiling_task);
+  bool is_all_reduce = (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HVDCALLBACKALLREDUCE);
+  if (is_all_reduce && is_insert_bp_profiling_task) {
+    int64_t log_id = 0;
+    (void)ge::AttrUtils::GetInt(op_desc, ATTR_NAME_INSERT_PROFILILNG_TASK_LOG_ID, log_id);
+    GELOGD("All reduce node profiling task log id: %ld after", log_id);
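+    // The all-reduce end trace task reuses the recorded start log id + 1, matching the
+    // (kProfilingArStartLogid, kProfilingArEndLogid) pair reserved per all-reduce in TaskGenerator.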
auto &compute_graph = MakeShared(kProfilingGraph); + GE_CHECK_NOTNULL(compute_graph); + + NodePtr node_ptr = nullptr; + vector task_def_list; + // Create all reduce end node + bool is_insert_bp_profiling_task = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_INSERT_BP_PROFILILNG_TASK, is_insert_bp_profiling_task); + bool is_all_reduce = (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HVDCALLBACKALLREDUCE); + if (is_all_reduce && is_insert_bp_profiling_task) { + int64_t log_id = 0; + (void)ge::AttrUtils::GetInt(op_desc, ATTR_NAME_INSERT_PROFILILNG_TASK_LOG_ID, log_id); + GELOGD("All reduce node profiling task log id: %ld after", log_id); + (void) GenerateArProfilingTask(op_desc, log_id + 1, task_def_list); + string op_name = string(kProfilingArNode) + std::to_string(log_id + 1); + auto ar_desc_end = MakeShared(op_name, PROFILINGTRAININGTRACE); + GE_CHECK_NOTNULL(ar_desc_end); + ar_desc_end->SetOpKernelLibName(kEngineNameRts); + node_ptr = compute_graph->AddNode(ar_desc_end); + GELOGD("Create all reduce end profiling node success after."); + } + // create bp node + if (!is_all_reduce && is_insert_bp_profiling_task) { + (void) GenerateBpProfilingTask(op_desc, task_def_list); + auto bp_op_desc = MakeShared(kProfilingBpNode, PROFILINGTRAININGTRACE); + GE_CHECK_NOTNULL(bp_op_desc); + bp_op_desc->SetOpKernelLibName(kEngineNameRts); + node_ptr = compute_graph->AddNode(bp_op_desc); + GELOGD("Create bp profiling node success after."); + } + // create end node + bool is_insert_end_profiling_task = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_INSERT_END_PROFILILNG_TASK, is_insert_end_profiling_task); + if (is_insert_end_profiling_task) { + (void)GenerateEndProfilingTask(op_desc, task_def_list); + auto end_desc = MakeShared(kProfilingEndNode, PROFILINGTRAININGTRACE); + GE_CHECK_NOTNULL(end_desc); + end_desc->SetOpKernelLibName(kEngineNameRts); + node_ptr = compute_graph->AddNode(end_desc); + GELOGD("Create end profiling node success after."); + } + + if (node_ptr != nullptr) { + for (const auto &task_def : task_def_list) { + hybrid_model_.task_defs_[node_ptr].emplace_back(task_def); + } + NodeItem *node_item = nullptr; + GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node_ptr, &node_item)); + node_item->input_start = 0; + node_item->output_start = 0; + graph_item.node_items_.emplace_back(node_item); + } else { + GELOGD("No need to create profiling node after."); + } + + return SUCCESS; +} + Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root_graph) { GELOGD("Start to load subgraph [%s]", graph.GetName().c_str()); // for known partitioned call, load all nodes @@ -1567,8 +1760,9 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root graph_item->output_node_ = node_item; GE_CHK_STATUS_RET_NOLOG(BuildOutputMapping(*graph_item, *node_item, is_root_graph)); } - + GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeBefore(*graph_item, node)); graph_item->node_items_.emplace_back(node_item); + GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeAfter(*graph_item, node)); // parse var outputs GE_CHK_STATUS_RET_NOLOG(ParseVarOutputs(*node_item)); GELOGD("NodeItem created: %s", node_item->DebugString().c_str()); diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index a11faae2..55a19b6c 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -79,6 +79,12 @@ class HybridModelBuilder { Status LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem 
+    rtError_t rt_ret = rtProfilerTrace(log_id, notify, flat, context.GetStream());
+    if (rt_ret != RT_ERROR_NONE) {
+      GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
+      return RT_ERROR_TO_GE_STATUS(rt_ret);
+    }
+    GELOGD("[%s] ProfilingTraceTask[%lu] execute success.", context.GetNodeName(), log_id);
+  }
+
+  return SUCCESS;
+}
+
 Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
+  GE_CHECK_NOTNULL(node);
+
   auto op_type = node->GetType();
   if (op_type == IDENTITY) {
     task = MakeShared<IdentityNodeTask>();
   } else if (op_type == IDENTITYN) {
     task = MakeShared<IdentityNNodeTask>();
+  } else if (op_type == PROFILINGTRAININGTRACE) {
+    auto *task_defs = model.GetTaskDefs(node);
+    if (task_defs == nullptr || task_defs->empty()) {
+      GELOGE(INTERNAL_ERROR, "Profiling node has no task to execute.");
+      return INTERNAL_ERROR;
+    }
+    task = MakeShared<ProfilingTraceNodeTask>(*task_defs);
   } else {
     GELOGE(INTERNAL_ERROR, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), op_type.c_str());
     return INTERNAL_ERROR;
diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.h b/ge/hybrid/node_executor/rts/rts_node_executor.h
index 2576b73b..df487d6c 100644
--- a/ge/hybrid/node_executor/rts/rts_node_executor.h
+++ b/ge/hybrid/node_executor/rts/rts_node_executor.h
@@ -18,6 +18,7 @@
 #define GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_EXECUTOR_H_

 #include "hybrid/node_executor/node_executor.h"
+#include "proto/task.pb.h"

 namespace ge {
 namespace hybrid {
@@ -35,6 +36,18 @@ class IdentityNNodeTask : public IdentityNodeTask {
   Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
 };

+class ProfilingTraceNodeTask : public NodeTask {
+ public:
+  explicit ProfilingTraceNodeTask(const std::vector<domi::TaskDef> &task_defs) : task_defs_(task_defs) {}
+  ~ProfilingTraceNodeTask() override = default;
+
+  Status UpdateArgs(TaskContext &context) override;
+  Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
+
+ private:
+  std::vector<domi::TaskDef> task_defs_;
+};
+
 class RtsNodeExecutor : public NodeExecutor {
  public:
   Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override;
diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h
index 0e85a8e3..8ba4fb90 100644
--- a/ge/hybrid/node_executor/task_context.h
+++ b/ge/hybrid/node_executor/task_context.h
@@ -123,7 +123,7 @@ class TaskContext {
   Status status_ = SUCCESS;
   std::vector<void *> workspaces_;
   uint64_t iteration_ = 0;
-  uint32_t task_id_= 0;
+  uint32_t task_id_ = 0;
   uint32_t stream_id_ = 0;
 };
 }  // namespace hybrid
diff --git a/inc/framework/common/ge_types.h b/inc/framework/common/ge_types.h
index 4267aec4..685e03fd 100644
--- a/inc/framework/common/ge_types.h
+++ b/inc/framework/common/ge_types.h
@@ -263,6 +263,8 @@ struct ComputeGraphDescInfo {
   std::vector<Format> output_format;
   std::vector<std::vector<int64_t>> output_shape;
   std::vector<DataType> output_data_type;
+  uint32_t task_id;
+  uint32_t stream_id;
 };

 struct OpDescInfo {
diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h
index 99c2ea03..e3baa816 100644
--- a/inc/framework/common/types.h
+++ b/inc/framework/common/types.h
@@ -529,6 +529,9 @@ REGISTER_OPTYPE_DECLARE(HVDWAIT, "HorovodWait");
 // aicpu op for online_infer dynamic_dims
 REGISTER_OPTYPE_DECLARE(GETDYNAMICDIMS, "GetDynamicDims");

+// profiling training trace node
+REGISTER_OPTYPE_DECLARE(PROFILINGTRAININGTRACE, "ProfilingTrainingTrace");
+
 enum InputMode { INPUT = 0, CONST_INPUT };

 // Definition of the processing status enum of the process module
diff --git a/metadef b/metadef
index 44bcbb5e..30cf97ba 160000
--- a/metadef
+++ b/metadef
@@ -1 +1 @@
-Subproject commit 44bcbb5ea25ada1a5393aa4c7f554d40b6859b18
+Subproject commit 30cf97ba0c9a70ade0d9df92695c9dcd671316f6
diff --git a/parser b/parser
index 5b93b050..e338bc22 160000
--- a/parser
+++ b/parser
@@ -1 +1 @@
-Subproject commit 5b93b050dd7ca5b77c3001a790031d877fa10956
+Subproject commit e338bc2200bed9f11f6e665e6ad37a3a97906354

From e73532dfb8dccffc7c3870a12edd649507c8b49e Mon Sep 17 00:00:00 2001
From: zhou_lili
Date: Mon, 4 Jan 2021 18:58:51 +0800
Subject: [PATCH 2/2] change switchn to case and add ut

---
 .../load/new_model_manager/davinci_model.cc   | 181 +++---
 .../load/new_model_manager/davinci_model.h    |  16 +-
 .../load/new_model_manager/model_manager.cc   |  12 +-
 .../load/new_model_manager/model_manager.h    |   6 +-
 .../task_info/hccl_task_info.cc               |   4 +-
 ge/graph/manager/graph_manager.cc             |   6 +-
 .../common_subexpression_elimination_pass.cc  |   6 +-
 ge/graph/passes/multi_batch_clone_pass.cc     | 553 +++++++++++++++---
 ge/graph/passes/multi_batch_clone_pass.h      |  58 +-
 ge/graph/passes/unused_args_clean_pass.cc     |   4 +
 ge/graph/preprocess/multi_batch_copy_graph.cc |  12 +-
 ge/graph/preprocess/multi_batch_options.cc    |   5 +-
 inc/framework/omg/omg_inner_types.h           |   3 +
 metadef                                       |   2 +-
 parser                                        |   2 +-
 tests/ut/ge/CMakeLists.txt                    |   1 +
 .../ge/graph/load/davinci_model_unittest.cc   | 101 ++++
 .../passes/multi_batch_clone_pass_unittest.cc | 247 ++++++++
 18 files changed, 1016 insertions(+), 203 deletions(-)
 create mode 100644 tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc

diff --git a/ge/graph/load/new_model_manager/davinci_model.cc b/ge/graph/load/new_model_manager/davinci_model.cc
index 0e97f9c0..ad5ee49b 100755
--- a/ge/graph/load/new_model_manager/davinci_model.cc
+++ b/ge/graph/load/new_model_manager/davinci_model.cc
@@ -87,6 +87,7 @@ const uint32_t kDumpL1FusionOpMByteSize = 2097152;  // 2 * 1024 * 1024
 const uint32_t kDumpFlagOfL1Fusion = 0;
 const char *const kDefaultBatchLable = "Batch_default";
"Batch_default"; const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; +const char *const kMultiBatchNodePostfix = "_ascend_mbatch_batch_"; const int32_t kInvalidStream = -1; const uint32_t kEndOfSequence = 0x0704000a; const uint32_t kEndOfSequenceNew = 507005; @@ -867,6 +868,10 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); return PARAM_INVALID; } + if (InitRealSizeAndShapeInfo(compute_graph, node) != SUCCESS) { + GELOGE(PARAM_INVALID, "Init real size and shape failed, Name: %s", op_desc->GetName().c_str()); + return PARAM_INVALID; + } continue; } @@ -1143,16 +1148,24 @@ Status DavinciModel::InitNetOutput(const ComputeGraphPtr &graph, const NodePtr & real_virtual_addrs_.insert(real_addr); } } + return SUCCESS; +} +Status DavinciModel::InitRealSizeAndShapeInfo(const ComputeGraphPtr &compute_graph, const NodePtr &node) { + if (node->GetName().find(kMultiBatchNodePostfix) != string::npos) { + GELOGD("No need to get size and shape of netoutput in subgraph."); + return SUCCESS; + } + GELOGD("Start init real size and shape info of %s.", node->GetName().c_str()); GetAllGearsInfo(node); if (is_getnext_sink_dynamic_) { GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS, GELOGE(PARAM_INVALID, "Failed to get info of getdynamicdims node."); return PARAM_INVALID;); } if (is_online_infer_dynamic_) { - GE_IF_BOOL_EXEC(GetGearAndRealOutSizeInfo(input_count, node) != SUCCESS, + GE_IF_BOOL_EXEC(GetGearAndRealOutSizeInfo(compute_graph, node) != SUCCESS, GELOGE(PARAM_INVALID, "Failed to get gear and real out size info."); return PARAM_INVALID;); - GE_IF_BOOL_EXEC(GetGearAndRealOutShapeInfo(input_count, op_desc) != SUCCESS, + GE_IF_BOOL_EXEC(GetGearAndRealOutShapeInfo(compute_graph, node) != SUCCESS, GELOGE(PARAM_INVALID, "Failed to get gear and real out shape info."); return PARAM_INVALID;); } @@ -1171,7 +1184,7 @@ void DavinciModel::GetAllGearsInfo(const NodePtr &node) { if (shape_str.empty()) { continue; } - std::vector gear_info; + std::vector gear_info; std::vector dims = ge::StringUtils::Split(shape_str, ','); for (const auto &dim : dims) { if (dim.empty()) { @@ -1187,6 +1200,7 @@ void DavinciModel::GetAllGearsInfo(const NodePtr &node) { } } } + Status DavinciModel::GetGetDynamicDimsNodeInfo(const NodePtr &node) { GE_CHECK_NOTNULL(node->GetOpDesc()); size_t input_count = node->GetAllInDataAnchors().size(); @@ -1224,11 +1238,11 @@ Status DavinciModel::GetGetDynamicDimsNodeInfo(const NodePtr &node) { return SUCCESS; } -Status DavinciModel::GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node) { - GELOGD("Start get gear and real output size info of %s, input count is %zu.", node->GetName().c_str(), input_count); +Status DavinciModel::GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, const NodePtr &node) { + GELOGD("Start get gear and real output size info of %s.", node->GetName().c_str()); merge_nodes_gear_and_real_out_size_info_.clear(); - for (size_t idx = 0; idx < input_count; ++idx) { - auto in_anchor = node->GetAllInDataAnchors().at(idx); + size_t idx = 0; + for (const auto &in_anchor : node->GetAllInDataAnchors()) { auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { continue; @@ -1236,89 +1250,106 @@ Status DavinciModel::GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr auto peer_node = peer_out_anchor->GetOwnerNode(); auto op_desc = peer_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - if 
-    if ((peer_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
-      if (GetRealOutputSizeOfMerge(idx, peer_node) != SUCCESS) {
+    if ((peer_node->GetType() == CASE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
+      if (GetRealOutputSizeOfCase(graph, idx, peer_node) != SUCCESS) {
         GELOGE(PARAM_INVALID, "Get real output size of %s failed.", peer_node->GetName().c_str());
         return PARAM_INVALID;
       }
     }
+    idx++;
   }
   return SUCCESS;
 }

-Status DavinciModel::GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node) {
-  GELOGD("Start get output size of %s, which is %zu input to netoutput.", merge_node->GetName().c_str(), input_index);
-  std::map<vector<int64_t>, int64_t> gear_and_real_out_size_info;
-  for (auto &in_anchor : merge_node->GetAllInDataAnchors()) {
-    auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
-    if (peer_out_anchor == nullptr) {
-      continue;
-    }
-    auto in_node = peer_out_anchor->GetOwnerNode();
-    GELOGD("Input node of merge is %s.", in_node->GetName().c_str());
-    auto op_desc = in_node->GetOpDesc();
-    GE_CHECK_NOTNULL(op_desc);
-    string batch_label;
-    if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
-      size_t batch_index = static_cast<size_t>(stoi(batch_label.substr(batch_label.rfind('_') + 1)));
-      GELOGD("Batch index of %s is %zu.", op_desc->GetName().c_str(), batch_index);
-      if (batch_index > all_gears_info_.size()) {
-        GELOGE(PARAM_INVALID, "The value of ATTR_NAME_BATCH_LABEL is invalid.");
-        return PARAM_INVALID;
-      }
-
-      const vector<int64_t> output_size_list = ModelUtils::GetOutputSize(op_desc);
-      int output_index = ge::AnchorUtils::GetIdx(peer_out_anchor);
-      auto tensor_desc = op_desc->GetOutputDescPtr(output_index);
-      GE_CHECK_NOTNULL(tensor_desc);
-      int64_t data_size = 0;
-      if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, data_size) != GRAPH_SUCCESS) {
-        GELOGE(FAILED, "Get tensor size in bytes failed.");
-        return FAILED;
-      }
-      gear_and_real_out_size_info[all_gears_info_[batch_index]] = data_size;
-      GELOGD("Get real gear index is: %zu, gear info is %s, size is %ld, tensor size is %ld",
-             batch_index, formats::JoinToString(all_gears_info_[batch_index]).c_str(),
-             output_size_list[output_index], data_size);
-    }
-  }
+Status DavinciModel::GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index,
+                                             const NodePtr &case_node) {
+  GELOGD("Start get output size of %s, which is %zu input to netoutput.", case_node->GetName().c_str(), input_index);
+  const auto &func_desc = case_node->GetOpDesc();
+  GE_CHECK_NOTNULL(func_desc);
+  std::map<vector<int32_t>, int64_t> gear_and_real_out_size_info;
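+  // Each subgraph of the Case node is one batch gear: its NETOUTPUT carries ATTR_NAME_BATCH_LABEL
+  // ending in "_ascend_mbatch_batch_N", and N indexes all_gears_info_.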
+  for (const auto &name : func_desc->GetSubgraphInstanceNames()) {
+    const auto &subgraph = graph->GetSubgraph(name);
+    if (subgraph == nullptr) {
+      GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s.", name.c_str());
+      return GE_GRAPH_EMPTY_SUBGRAPH;
+    }
+    for (auto &node : subgraph->GetDirectNode()) {
+      if (node->GetType() == NETOUTPUT) {
+        auto op_desc = node->GetOpDesc();
+        GE_CHECK_NOTNULL(op_desc);
+        string batch_label;
+        if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
+          size_t batch_index = static_cast<size_t>(stoi(batch_label.substr(batch_label.rfind('_') + 1)));
+          GELOGD("Batch index of %s is %zu.", op_desc->GetName().c_str(), batch_index);
+          if (batch_index > all_gears_info_.size()) {
+            GELOGE(PARAM_INVALID, "The value of ATTR_NAME_BATCH_LABEL is invalid.");
+            return PARAM_INVALID;
+          }
+
+          const vector<int64_t> input_size_list = ModelUtils::GetInputSize(op_desc);
+          auto tensor_desc = op_desc->GetInputDescPtr(input_index);
+          GE_CHECK_NOTNULL(tensor_desc);
+          int64_t data_size = 0;
+          if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, data_size) != GRAPH_SUCCESS) {
+            GELOGE(FAILED, "Get tensor size in bytes failed.");
+            return FAILED;
+          }
+          gear_and_real_out_size_info[all_gears_info_[batch_index]] = data_size;
+          GELOGD("Get real gear index is: %zu, gear info is %s, size is %ld, tensor size is %ld",
+                 batch_index, formats::JoinToString(all_gears_info_[batch_index]).c_str(),
+                 input_size_list[input_index], data_size);
+        }
+        break;
+      }
+    }
+  }
   merge_nodes_gear_and_real_out_size_info_[input_index] = gear_and_real_out_size_info;
   return SUCCESS;
 }

-Status DavinciModel::GetGearAndRealOutShapeInfo(size_t input_count, const OpDescPtr &op_desc) {
-  GELOGD("Start to get dynamic output dims of %s.", op_desc->GetName().c_str());
+Status DavinciModel::GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, const NodePtr &node) {
+  GELOGD("Start to get dynamic output dims of %s.", node->GetName().c_str());
   merge_nodes_gear_and_real_out_shape_info_.clear();
-  std::vector<std::string> dynamic_output_shape_info;
-  if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_shape_info)) {
-    GELOGD("Can not get dynamic output dims attr");
-    return SUCCESS;
-  }
-  GELOGI("Dynamic output shape info is %s", formats::JoinToString(dynamic_output_shape_info).c_str());
-  std::vector<vector<int64_t>> dynamic_output_shape;
-  ParseDynamicOutShape(dynamic_output_shape_info, dynamic_output_shape);
-  // idx: input_index to netoutput
-  for (size_t idx = 0; idx < input_count; ++idx) {
-    std::map<vector<int64_t>, vector<int64_t>> gear_and_real_out_shape_info;
-    for (auto &it : dynamic_output_shape) {
-      auto gear_index = static_cast<size_t>(it[0]);
-      if (gear_index > all_gears_info_.size()) {
-        GELOGE(PARAM_INVALID, "The value of cur index: %zu is invalid.", static_cast<size_t>(it[0]));
-        return PARAM_INVALID;
+  size_t idx = 0;
+  for (const auto &in_anchor : node->GetAllInDataAnchors()) {
+    auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
+    if (peer_out_anchor == nullptr) {
+      continue;
+    }
+    auto peer_node = peer_out_anchor->GetOwnerNode();
+    auto op_desc = peer_node->GetOpDesc();
+    GE_CHECK_NOTNULL(op_desc);
+    if ((peer_node->GetType() == CASE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
+      std::vector<std::string> dynamic_output_shape_info;
+      if (!AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_shape_info)) {
+        GELOGD("Can not get dynamic output dims attr from %s.", node->GetName().c_str());
+        return SUCCESS;
       }
+      GELOGI("Dynamic output shape info is %s", formats::JoinToString(dynamic_output_shape_info).c_str());
+      std::vector<vector<int64_t>> dynamic_output_shape;
+      ParseDynamicOutShape(dynamic_output_shape_info, dynamic_output_shape);
+      std::map<vector<int32_t>, vector<int64_t>> gear_and_real_out_shape_info;
+      for (auto &it : dynamic_output_shape) {
+        auto gear_index = static_cast<size_t>(it[0]);
+        if (gear_index > all_gears_info_.size()) {
+          GELOGE(PARAM_INVALID, "The value of cur index: %zu is invalid.", static_cast<size_t>(it[0]));
+          return PARAM_INVALID;
+        }

-      if (static_cast<size_t>(it[1]) == idx) {
-        vector<int64_t> output_shape;
-        for (size_t i = 2; i < it.size(); ++i) {
-          output_shape.emplace_back(it[i]);
+        if (static_cast<size_t>(it[1]) == idx) {
+          vector<int64_t> output_shape;
+          for (size_t i = 2; i < it.size(); ++i) {
+            output_shape.emplace_back(it[i]);
+          }
+          gear_and_real_out_shape_info[all_gears_info_[gear_index]] = output_shape;
+          GELOGD("Get real gear index is: %zu, gear info is %s, output shape is %s.",
+                 gear_index, formats::JoinToString(all_gears_info_[gear_index]).c_str(),
+                 formats::JoinToString(output_shape).c_str());
         }
-        gear_and_real_out_shape_info[all_gears_info_[gear_index]] = output_shape;
-        GELOGD("Get real gear index is: %zu, gear info is %s, output shape is %s.",
-               gear_index, formats::JoinToString(all_gears_info_[gear_index]).c_str(),
-               formats::JoinToString(output_shape).c_str());
       }
+      merge_nodes_gear_and_real_out_shape_info_[idx] = gear_and_real_out_shape_info;
     }
-    merge_nodes_gear_and_real_out_shape_info_[idx] = gear_and_real_out_shape_info;
+    idx++;
   }
   return SUCCESS;
 }
@@ -1962,7 +1993,7 @@ void DavinciModel::CreateOutput(uint32_t index, const OpDescPtr &op_desc, InputO
                                 uint32_t &format_result) {
   /// netoutput input tensor desc
   GE_IF_BOOL_EXEC(op_desc->GetInputDescPtr(index) == nullptr, GELOGE(FAILED, "OpDesc GetInputDescPtr is nullptr");
-                  return );
+                  return);
   Format format = op_desc->GetInputDescPtr(index)->GetFormat();
   GeShape shape = op_desc->GetInputDescPtr(index)->GetShape();
   DataType data_type = op_desc->GetInputDescPtr(index)->GetDataType();
@@ -2567,7 +2598,7 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b
     GELOGD("Reinit cur dynamic dims when getnext sink dynamic.");
     cur_dynamic_dims_.clear();
     cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_);
-    auto ret = rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t),
+    auto ret = rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int32_t),
                         netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST);
     GE_CHK_RT_RET(ret);
   }
@@ -2668,11 +2699,11 @@ void *DavinciModel::Run(DavinciModel *model) {
       GE_IF_BOOL_EXEC(current_data.blobs.empty(), break);
       auto shape_data_buffer_data = current_data.blobs.back().data;
       auto shape_data_buffer_length = current_data.blobs.back().length;
-      model->cur_dynamic_dims_.assign(reinterpret_cast<int64_t *>(shape_data_buffer_data),
-                                      reinterpret_cast<int64_t *>(shape_data_buffer_data) +
-                                          shape_data_buffer_length / sizeof(int64_t));
+      model->cur_dynamic_dims_.assign(reinterpret_cast<int32_t *>(shape_data_buffer_data),
+                                      reinterpret_cast<int32_t *>(shape_data_buffer_data) +
+                                          shape_data_buffer_length / sizeof(int32_t));
       GELOGD("Data: cur dynamic dims is %s", formats::JoinToString(model->cur_dynamic_dims_).c_str());
-      delete[] reinterpret_cast<int64_t *>(current_data.blobs.back().data);
+      delete[] reinterpret_cast<int32_t *>(current_data.blobs.back().data);
       current_data.blobs.pop_back();
     }
     GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_PRE_PROC_END));
diff --git a/ge/graph/load/new_model_manager/davinci_model.h b/ge/graph/load/new_model_manager/davinci_model.h
index cb3902df..893dfc2a 100755
--- a/ge/graph/load/new_model_manager/davinci_model.h
+++ b/ge/graph/load/new_model_manager/davinci_model.h
@@ -864,11 +864,13 @@ class DavinciModel {
   void ParseDynamicOutShape(const vector<std::string> &str_info, vector<vector<int64_t>> &vec_info);
   bool IsGetNextSinkDynamic(const OpDescPtr &op_desc);
+
+  Status InitRealSizeAndShapeInfo(const ComputeGraphPtr &compute_graph, const NodePtr &node);
   void GetAllGearsInfo(const NodePtr &node);
   Status GetGetDynamicDimsNodeInfo(const NodePtr &node);
-  Status GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node);
-  Status GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node);
-  Status GetGearAndRealOutShapeInfo(size_t input_count, const OpDescPtr &op_desc);
+  Status GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, const NodePtr &node);
+  Status GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index, const NodePtr &case_node);
+  Status GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, const NodePtr &node);

   bool is_weight_mem_has_inited_;
   bool is_feature_map_mem_has_inited_;
@@ -1023,15 +1025,15 @@ class DavinciModel {
   bool is_new_model_desc_{false};
   bool is_online_infer_dynamic_ = false;
   bool is_getnext_sink_dynamic_ = false;
-  vector<int64_t> cur_dynamic_dims_;
+  vector<int32_t> cur_dynamic_dims_;
   void *netoutput_last_input_addr_ = nullptr;
   int64_t netoutput_last_input_size_ = 0;
   size_t shape_of_cur_dynamic_dims_ = 0;
   // key: input_index: input is merge node; value: each gear info and each output size
-  map<size_t, map<vector<int64_t>, int64_t>> merge_nodes_gear_and_real_out_size_info_;
+  map<size_t, map<vector<int32_t>, int64_t>> merge_nodes_gear_and_real_out_size_info_;
   // key: input_index: input is merge node; value: each gear info and each output shape
-  map<size_t, map<vector<int64_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_;
-  vector<vector<int64_t>> all_gears_info_;
+  map<size_t, map<vector<int32_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_;
+  vector<vector<int32_t>> all_gears_info_;

   multimap<uint32_t, uint32_t> op_id_map_;
   vector<ProfileInfo> profile_list_;
diff --git a/ge/graph/load/new_model_manager/model_manager.cc b/ge/graph/load/new_model_manager/model_manager.cc
index 6f923236..b2cce73a 100755
--- a/ge/graph/load/new_model_manager/model_manager.cc
+++ b/ge/graph/load/new_model_manager/model_manager.cc
@@ -460,8 +460,8 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d

 Status ModelManager::GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims,
                                        const vector<pair<string, vector<int64_t>>> &user_input_dims,
-                                       vector<int64_t> &cur_dynamic_dims) {
-  GELOGD(" Start get cur dynamic dims.");
+                                       vector<int32_t> &cur_dynamic_dims) {
+  GELOGD("Start get cur dynamic dims.");
   if (user_real_input_dims.size() != user_input_dims.size()) {
     GELOGE(INTERNAL_ERROR,
            "The input count of user: %zu should be equal to the data count of graph: %zu",
@@ -478,7 +478,7 @@ Status ModelManager::GetCurDynamicDims(const vector<vector<int64_t>> &user_real_
     }
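+    // Gather the runtime value of every dimension the user declared as -1; the collected int32
+    // values are the dynamic-dims blob appended to the inputs in DataInputTensor.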
ComputeGraphPtr &graph, const NodePtr &node); bool is_weight_mem_has_inited_; bool is_feature_map_mem_has_inited_; @@ -1023,15 +1025,15 @@ class DavinciModel { bool is_new_model_desc_{false}; bool is_online_infer_dynamic_ = false; bool is_getnext_sink_dynamic_ = false; - vector cur_dynamic_dims_; + vector cur_dynamic_dims_; void *netoutput_last_input_addr_ = nullptr; int64_t netoutput_last_input_size_ = 0; size_t shape_of_cur_dynamic_dims_ = 0; // key: input_index: input is merge node; value: each gear info and each output size - map, int64_t>> merge_nodes_gear_and_real_out_size_info_; + map, int64_t>> merge_nodes_gear_and_real_out_size_info_; // key: input_index: input is merge node; value: each gear info and each output shape - map, vector>> merge_nodes_gear_and_real_out_shape_info_; - vector> all_gears_info_; + map, vector>> merge_nodes_gear_and_real_out_shape_info_; + vector> all_gears_info_; multimap op_id_map_; vector profile_list_; diff --git a/ge/graph/load/new_model_manager/model_manager.cc b/ge/graph/load/new_model_manager/model_manager.cc index 6f923236..b2cce73a 100755 --- a/ge/graph/load/new_model_manager/model_manager.cc +++ b/ge/graph/load/new_model_manager/model_manager.cc @@ -460,8 +460,8 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d Status ModelManager::GetCurDynamicDims(const vector> &user_real_input_dims, const vector>> &user_input_dims, - vector &cur_dynamic_dims) { - GELOGD(" Start get cur dynamic dims."); + vector &cur_dynamic_dims) { + GELOGD("Start get cur dynamic dims."); if (user_real_input_dims.size() != user_input_dims.size()) { GELOGE(INTERNAL_ERROR, "The input count of user: %zu should be equal to the data count of graph: %zu", @@ -478,7 +478,7 @@ Status ModelManager::GetCurDynamicDims(const vector> &user_real_ } for (size_t j = 0; j < user_input_dims.at(i).second.size(); ++j) { if (user_input_dims.at(i).second.at(j) < 0) { - cur_dynamic_dims.emplace_back(user_real_input_dims[i][j]); + cur_dynamic_dims.emplace_back(static_cast(user_real_input_dims[i][j])); } } } @@ -523,7 +523,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector cur_dynamic_dims; + std::vector cur_dynamic_dims; if (!GetLocalOmgContext().user_real_input_dims.empty()) { if (GetCurDynamicDims(GetLocalOmgContext().user_real_input_dims, GetLocalOmgContext().user_input_dims, cur_dynamic_dims) != SUCCESS) { @@ -531,9 +531,9 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector(cur_dynamic_dims.size() * sizeof(int64_t)); + uint32_t length = static_cast(cur_dynamic_dims.size() * sizeof(int32_t)); GE_CHK_BOOL_EXEC(memcpy_s(data.data, length, cur_dynamic_dims.data(), length) == EOK, return INTERNAL_ERROR, "Failed to memcpy data."); data.length = length; diff --git a/ge/graph/load/new_model_manager/model_manager.h b/ge/graph/load/new_model_manager/model_manager.h index 088ea5fd..500cad31 100755 --- a/ge/graph/load/new_model_manager/model_manager.h +++ b/ge/graph/load/new_model_manager/model_manager.h @@ -126,14 +126,14 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// /// @ingroup domi_ome /// @brief Get cur_dynamic_dims for all input. - /// @param [in] vector> &user_real_input_dims: dims info of all user_inputs. + /// @param [in] vector> &user_real_input_dims: dims info of all user_inputs. /// @param [in] vector>> &user_input_dims: key:name. value:dynamic dims from option. - /// @param [out] vector &cur_dynamic_dims: real dims gather, where the index of -1. 
+  /// @param [out] vector<int32_t> &cur_dynamic_dims: real dims gather, where the index of -1.
   /// @return 0: SUCCESS / others: INTERNAL_ERROR
   ///
   Status GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims,
                            const vector<pair<string, vector<int64_t>>> &user_input_dims,
-                           vector<int64_t> &cur_dynamic_dims);
+                           vector<int32_t> &cur_dynamic_dims);

   ///
   /// @ingroup domi_ome
diff --git a/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc
index df43fd5b..8033c93e 100644
--- a/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc
+++ b/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc
@@ -145,7 +145,9 @@ Status HcclTaskInfo::SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciM
   } else {
     GELOGI("need to reuse follow stream and create new follow stream.");
     size_t created_stream_num = follow_stream_usage.size();
-    hccl_stream_list_ = follow_stream_usage;
+    for (const auto &stream : follow_stream_usage) {
+      hccl_stream_list_.emplace_back(stream);
+    }
     ret = CreateStream(hccl_stream_num - created_stream_num, davinci_model, main_stream_id);
     if (ret != SUCCESS) {
       GELOGE(RT_FAILED, "Create hccl stream failed.");
diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc
index 6372a018..38de6ff7 100755
--- a/ge/graph/manager/graph_manager.cc
+++ b/ge/graph/manager/graph_manager.cc
@@ -2780,8 +2780,10 @@ Status GraphManager::ParseInputsDims(const std::vector<InputTensorInfo> &input_t
   if (!GetLocalOmgContext().dynamic_node_type.empty()) {
     vector<NodePtr> data_nodes;
     vector<NodePtr> getnext_nosink_nodes;
-    data_nodes = compute_graph_->TryGetExtAttr(kExtAttrDataNodes, data_nodes);
-    getnext_nosink_nodes = compute_graph_->TryGetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes);
+    data_nodes = GetLocalOmgContext().data_nodes;
+    getnext_nosink_nodes = GetLocalOmgContext().getnext_nosink_nodes;
+    GELOGD("Data nodes count is %zu, getnext nosink nodes count is %zu.", data_nodes.size(),
+           getnext_nosink_nodes.size());
     if (GetLocalOmgContext().dynamic_node_type == DATA) {
       if (getnext_nosink_nodes.empty()) {
         // just data or data+getnext_sink
diff --git a/ge/graph/passes/common_subexpression_elimination_pass.cc b/ge/graph/passes/common_subexpression_elimination_pass.cc
index a4662d5d..7d9724fc 100644
--- a/ge/graph/passes/common_subexpression_elimination_pass.cc
+++ b/ge/graph/passes/common_subexpression_elimination_pass.cc
@@ -26,6 +26,10 @@

 namespace ge {
 namespace {
+std::set<std::string> un_compute_attrs = {
+    {ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES},
+};
+
 std::string GetCseKey(const NodePtr &node) {
   std::stringstream ss;
   ss << node->GetType() << "-data-inputs-";
@@ -49,7 +53,7 @@ std::string GetCseKey(const NodePtr &node) {
     ss << name << "-";
   }

-  ss << "attrs-" << AttrUtils::GetAllAttrsStr(node->GetOpDesc());
+  ss << "attrs-" << AttrUtils::GetAttrsStrAfterRid(node->GetOpDesc(), un_compute_attrs);
   return ss.str();
 }
diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc
index f8451ace..b7efa070 100755
--- a/ge/graph/passes/multi_batch_clone_pass.cc
+++ b/ge/graph/passes/multi_batch_clone_pass.cc
@@ -25,31 +25,65 @@
 #include "graph/utils/tensor_utils.h"
 #include "graph/utils/type_utils.h"
 #include "register/op_registry.h"
+#include "graph/common/omg_util.h"

 namespace ge {
 namespace {
 constexpr uint8_t kDataInIndex = 0;
 constexpr uint8_t kDataOutIndex = 0;
 constexpr uint8_t kCaseArgIndex = 1;
+const int kDivisionConst = 2;
+const size_t kNumOfGetnextNode = 1;

 const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case";
kMultiBatchDataNode = "ascend_mbatch_shape_data";
+const std::string kMultiBatchGetDynamicDimsNode = "ascend_mbatch_get_dynamic_dims_node";
 const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const";
 const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex";
 const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_";
+const char *const kGetNextName = "IteratorV2";
 } // namespace
+inline bool IsGetNextType(const NodePtr &node) {
+  std::string original_type;
+  GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS,
+                  GELOGW("Get original type failed."); return false);
+  return (original_type == kGetNextName);
+}
+
 Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
+  GE_IF_BOOL_EXEC(graph == nullptr, GELOGE(FAILED, "Original graph is nullptr"); return FAILED);
   if (graph->GetParentGraph() != nullptr) {
     GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str());
     return SUCCESS;
   }
-
+  if (!GetLocalOmgContext().need_multi_batch) {
+    GELOGI("No need to process multi-batch for this graph.");
+    return SUCCESS;
+  }
+  std::vector<NodePtr> data_nodes;
+  std::vector<NodePtr> getnext_nosink_nodes;
+  std::vector<NodePtr> getnext_sink_nodes;
+  if (multibatch::CheckSequenceOfOptions(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
+    GELOGE(PARAM_INVALID, "[Train_Dynamic] CheckSequenceOfOptions failed.");
+    return PARAM_INVALID;
+  }
+  if (multibatch::UpdateNameOfInputShape(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
+    GELOGE(PARAM_INVALID, "[Train_Dynamic] UpdateNameOfInputShape failed.");
+    return PARAM_INVALID;
+  }
+  if (multibatch::DeleteIdentityInsertByAdapter(graph) != SUCCESS) {
+    GELOGE(PARAM_INVALID, "[Train_Dynamic] DeleteIdentityInsertByAdapter failed.");
+    return PARAM_INVALID;
+  }
   if (!multibatch::InitDynamicParams(batch_shapes_)) {
     GELOGD("There is no multi-batch options, no need clone multi-batch graph");
     return SUCCESS;
   }
-
+  if (multibatch::CheckNegativeCountOfOptions(batch_shapes_) != SUCCESS) {
+    GELOGE(PARAM_INVALID, "[Train_Dynamic] The input_shape and dynamic_dims options are not set correctly.");
+    return PARAM_INVALID;
+  }
   GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str());
   GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param");
   if (CollectIoNodes(graph) != SUCCESS) {
@@ -66,21 +100,14 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
   (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
 
   ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName());
-  if (branch == nullptr) {
-    GELOGE(OUT_OF_MEMORY, "Create multi-batch graph failed");
-    return OUT_OF_MEMORY;
-  }
+  GE_IF_BOOL_EXEC(branch == nullptr, GELOGE(OUT_OF_MEMORY, "Create multi-batch graph failed"); return OUT_OF_MEMORY);
 
   (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
 
   graph->InValid();  // Will modify, need topological again.
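   // The Swap below moves the original graph content into `branch`: `graph` is then rebuilt
   // as the root graph holding the Case node, while `branch` seeds the per-gear subgraphs
   // cloned later in CreateSubgraphs.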
graph->Swap(*branch); - if (CreateRootGraph(graph) != SUCCESS) { - return FAILED; - } - - if (CreateSubgraphs(graph, branch) != SUCCESS) { - return FAILED; - } + GE_CHK_STATUS_RET(CreateRootGraph(graph), "Construct root graph failed."); + GE_CHK_STATUS_RET(CreateOriGraph(branch), "Construct original graph failed.") + GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed."); GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); GELOGD("MultiBatchClonePass Leave"); @@ -95,9 +122,13 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { /// Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { for (const auto &node : graph->GetDirectNode()) { + if (!GetLocalOmgContext().dynamic_node_type.empty() && IsGetNextType(node)) { + all_data_nodes_.emplace_back(node); + GE_CHK_STATUS_RET(InitParamsOfGetNext(node), "Init params of %s failed.", node->GetName().c_str()); + } if (node->GetType() == DATA) { all_data_nodes_.emplace_back(node); - } else if (node->GetType() == CONSTANT) { + } else if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) { all_const_nodes_.emplace_back(node); } else if (node->GetType() == NETOUTPUT) { all_output_nodes_.emplace_back(node); @@ -114,10 +145,16 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { } int64_t data_index = 0; + size_t getnext_node_count = 0; for (size_t i = 0; i < all_data_nodes_.size(); ++i) { + if (IsGetNextType(all_data_nodes_[i])) { + // just one getnext node in graph + getnext_node_count++; + continue; + } const auto &op_desc = all_data_nodes_[i]->GetOpDesc(); if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) { - (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i); + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i - getnext_node_count); } } @@ -133,7 +170,43 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { "Remove edge failed"); } } + GELOGD("Data count is %zu, const count is %zu, getnext count is %zu, output count is %zu, direct out count is %zu.", + all_data_nodes_.size(), all_const_nodes_.size(), getnext_node_count, all_output_nodes_.size(), + direct_output_.size()); + + return SUCCESS; +} +Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) { + data_count_from_getnext_ = 0; + getnext_sink_dynamic_dims_ = false; + GE_CHECK_NOTNULL(node->GetOpDesc()); + data_count_from_getnext_ = node->GetOpDesc()->GetOutputsSize(); + if (GetLocalOmgContext().dynamic_node_type == GETNEXT) { + data_count_from_getnext_ = data_count_from_getnext_ / kDivisionConst; + for (size_t i = 0; i < data_count_from_getnext_; ++i) { + GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(i); + GELOGD("The %zu data shape from getnext sink is %s.", i, + formats::JoinToString(output_desc.GetShape().GetDims()).c_str()); + const auto &dims = output_desc.GetShape().GetDims(); + if (std::all_of(dims.begin(), dims.end(), [](int64_t val) {return val >= 0; })) { + GELOGD("The %zu data from %s is static.", i, node->GetName().c_str()); + } else { + getnext_sink_dynamic_dims_ = true; + GELOGD("Dynamic dims in the pattern of getnext sink."); + } + } + } + if (node->GetOutControlAnchor() != nullptr) { + for (const auto &peer_in_control_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { + NodePtr next_node = peer_in_control_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(next_node); + if (next_node->GetType() == CONSTANTOP) { + out_control_nodes_.insert(next_node); + GELOGD("Control edge: %s connect with %s.", 
node->GetName().c_str(), next_node->GetName().c_str()); + } + } + } return SUCCESS; } @@ -144,7 +217,11 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { /// @return 0: SUCCESS / others: FAILED /// Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { + GELOGD("Start create root graph of %s.", graph->GetName().c_str()); uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size(); + if (data_count_from_getnext_ != 0) { + input_num = input_num + data_count_from_getnext_ - kNumOfGetnextNode; + } uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize(); OpDescBuilder op_builder(kMultiBatchCaseNode, CASE); @@ -185,6 +262,10 @@ Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { op_desc->GetName().c_str()); return FAILED; } + if (!AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) { + GELOGE(INTERNAL_ERROR, "Failed to add insert attr on case node %s", op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed"); GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed"); @@ -202,7 +283,7 @@ Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { /// @param [in] NodePtr node: index data node. /// @return 0: SUCCESS / others: FAILED /// -Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node) { +Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node) { const OpDescPtr data_desc = MakeShared(kMultiBatchDataNode, DATA); if (data_desc == nullptr) { GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); @@ -220,11 +301,12 @@ Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, No } size_t data_index = all_data_nodes_.size(); + data_index = data_count_from_getnext_ != 0 ? 
data_index - kNumOfGetnextNode : data_index; (void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index); (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true); - node = graph->AddNode(data_desc); - if (node == nullptr) { + shape_node = graph->AddNode(data_desc); + if (shape_node == nullptr) { GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); return OUT_OF_MEMORY; } @@ -286,15 +368,19 @@ Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, N /// @return 0: SUCCESS / others: FAILED /// Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { - // Data --> MapIndex --> Case - NodePtr data_node; - GE_CHK_STATUS_RET(CreateIndexDataNode(graph, data_node), "Create data node failed"); + // Data/GetDynamicDims --> MapIndex --> Case + if (!getnext_sink_dynamic_dims_) { + GE_CHK_STATUS_RET(CreateIndexDataNode(graph, shape_node_), "Create data node failed"); + } else { + GE_CHK_STATUS_RET(CreateGetDynamicDimsNode(graph, shape_node_), "Create get dynamic dims node failed"); + } NodePtr const_node; GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed"); - + GELOGD("Shape node name is %s, type is %s, const node name is %s.", shape_node_->GetName().c_str(), + shape_node_->GetType().c_str(), const_node->GetName().c_str()); OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex"); - op_builder.AddInput("x", data_node->GetOpDesc()->GetOutputDesc(0)) + op_builder.AddInput("x", shape_node_->GetOpDesc()->GetOutputDesc(0)) .AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0)) .AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32)); @@ -309,8 +395,10 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { return OUT_OF_MEMORY; } - if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", data_node->GetName().c_str(), + GE_CHK_STATUS_RET(AddAttrForGetDynamicDims(shape_node_), "Failed to add attr for %s.", + shape_node_->GetName().c_str()); + if (GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", shape_node_->GetName().c_str(), index_node->GetName().c_str()); return FAILED; } @@ -328,6 +416,120 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { return SUCCESS; } +Status MultiBatchClonePass::CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node) { + const OpDescPtr data_desc = MakeShared(kMultiBatchGetDynamicDimsNode, GETDYNAMICDIMS); + if (data_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch get dynamic dims node failed"); + return OUT_OF_MEMORY; + } + + // input of GetDynamicDims is shape_of_each_data, output is gear_info + for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) { + size_t input_shape_dims = GetLocalOmgContext().user_input_dims.at(i).second.size(); + // add input desc without GeShape for const input, value of input_shape is 1 transferred by adapter + if (input_shape_dims == 1 && GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) { + GeTensorDesc tensor_desc; + tensor_desc.SetFormat(FORMAT_ND); + tensor_desc.SetDataType(DT_INT32); + auto ret = data_desc->AddInputDesc(tensor_desc); + GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data"); + return FAILED); + continue; + } 
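+    // An input recorded as exactly {0} in user_input_dims is a const materialized by the
+    // adapter: it received the scalar placeholder desc above, while a real data input gets
+    // a 1-D desc whose single dim is the rank of that input's shape, built below.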
+ GeTensorDesc tensor_desc(GeShape({static_cast(input_shape_dims)}), FORMAT_ND, DT_INT32); + auto ret = data_desc->AddInputDesc(tensor_desc); + GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data"); + return FAILED); + } + GeTensorDesc tensor_desc(GeShape({static_cast(batch_shapes_.at(0).size())}), FORMAT_ND, DT_INT32); + auto ret = data_desc->AddOutputDesc(tensor_desc); + GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data"); + return FAILED); + + (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true); + + shape_node = graph->AddNode(data_desc); + if (shape_node == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch dynamic dims node failed"); + return OUT_OF_MEMORY; + } + return SUCCESS; +} + +Status MultiBatchClonePass::AddAttrForGetDynamicDims(const NodePtr &shape_node) { + if (!getnext_sink_dynamic_dims_) { + GELOGD("No need to add attr when not insert get dynamic dims node."); + return SUCCESS; + } + GELOGD("Add attr for :%s, type is %s:", shape_node->GetName().c_str(), shape_node->GetType().c_str()); + if (!AttrUtils::SetInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_DATA_COUNT, data_count_from_getnext_)) { + GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_DATA_COUNT failed"); + return INTERNAL_ERROR; + } + vector shape_info; + for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) { + if (GetLocalOmgContext().user_input_dims.at(i).second.size() == 1 && + GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) { + shape_info.emplace_back(0); + continue; + } + shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.size()); + for (size_t j = 0; j < GetLocalOmgContext().user_input_dims.at(i).second.size(); ++j) { + shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.at(j)); + } + } + if (!AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_SHAPE_INFO, shape_info)) { + GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_SHAPE_INFO failed"); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status MultiBatchClonePass::LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node) { + GELOGD("Start relink shape anchor of %s to %s.", getnext_node->GetName().c_str(), shape_node->GetName().c_str()); + size_t input_index = 0; + size_t data_count = getnext_node->GetAllOutDataAnchors().size() / kDivisionConst; + for (size_t out_index = data_count; out_index < getnext_node->GetAllOutDataAnchors().size(); ++out_index, + ++input_index) { + GELOGD("Start add %s of %zu out_anchor to %s of %zu in_anchor.", getnext_node->GetName().c_str(), out_index, + shape_node->GetName().c_str(), input_index); + auto out_data_anchor = getnext_node->GetOutDataAnchor(out_index); + auto ret = GraphUtils::AddEdge(out_data_anchor, shape_node->GetInDataAnchor(input_index)); + GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link getnext %s to getdynamicdims %s", + getnext_node->GetName().c_str(), shape_node->GetName().c_str()); + return INTERNAL_ERROR); + } + return SUCCESS; +} + +Status MultiBatchClonePass::LinkGetDynamicDimsToNetOutput(const NodePtr &output_node) { + if (!GetLocalOmgContext().dynamic_node_type.empty()) { + if (!AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, GetLocalOmgContext().dynamic_dims)) { + GELOGE(INTERNAL_ERROR, "Failed to set all gears info attr on netoutput %s.", output_node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + if 
(getnext_sink_dynamic_dims_) {
+    GELOGD("Start link %s to %s.", shape_node_->GetName().c_str(), output_node->GetName().c_str());
+    size_t input_index = output_node->GetAllInDataAnchors().size();
+    if (NodeUtils::AppendInputAnchor(output_node, input_index + 1) != GRAPH_SUCCESS) {
+      GELOGE(INTERNAL_ERROR, "Append input anchor of %s (index %zu) failed.", output_node->GetName().c_str(),
+             input_index);
+      return INTERNAL_ERROR;
+    }
+    auto ret = GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(kDataOutIndex),
+                                   output_node->GetInDataAnchor(input_index));
+    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link getdynamicdims %s to netoutput %s",
+                    shape_node_->GetName().c_str(), output_node->GetName().c_str());
+                    return INTERNAL_ERROR);
+    if (!AttrUtils::SetBool(output_node->GetOpDesc(), ATTR_GETNEXT_SINK_DYNMAIC, true)) {
+      GELOGE(INTERNAL_ERROR, "Failed to set getnext sink dynamic attr on netoutput %s.",
+             output_node->GetName().c_str());
+      return INTERNAL_ERROR;
+    }
+  }
+  return SUCCESS;
+}
+
 ///
 /// @ingroup ge
 /// @brief Create input node for root graph.
@@ -337,8 +539,10 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) {
 Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) {
   // Data --> Case
   std::vector<NodePtr> all_data_nodes;
-  const size_t arg_index = kCaseArgIndex;
-  for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
+  size_t case_input_index = kCaseArgIndex;
+  NodePtr getnext_node = nullptr;
+  size_t input_index_of_getnext = 0;
+  for (size_t i = 0; i < all_data_nodes_.size(); ++i, ++case_input_index) {
     const auto &node = all_data_nodes_[i];
     const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
     if (op_desc == nullptr) {
@@ -353,22 +557,60 @@ Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) {
     op_desc->SetName(node->GetName());
     const NodePtr &data = graph->AddNode(op_desc);
     GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
-    if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
-      GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s",
-             data->GetName().c_str(), case_node_->GetName().c_str());
-      return FAILED;
+    if (IsGetNextType(node)) {
+      getnext_node = data;
+      input_index_of_getnext = case_input_index;
+      case_input_index = case_input_index + data_count_from_getnext_;
+      continue;
+    } else {
+      if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(case_input_index)) !=
+          GRAPH_SUCCESS) {
+        GELOGE(FAILED, "Failed to add edge between Data:%s and Case:%s", data->GetName().c_str(),
+               case_node_->GetName().c_str());
+        return FAILED;
+      }
     }
-    if (SetMaxShapeToData(data) != SUCCESS) {
+    if (SetMaxShape(data) != SUCCESS) {
+      GELOGE(FAILED, "Set max shape of %s failed.", data->GetName().c_str());
       return FAILED;
     }
     all_data_nodes.emplace_back(data);
   }
+  if (getnext_node != nullptr) {
+    if (LinkEdgeForGetNext(getnext_node, input_index_of_getnext) != SUCCESS) {
+      GELOGE(FAILED, "Failed to link edge for %s.", getnext_node->GetName().c_str());
+      return FAILED;
+    }
+    if (SetMaxShape(getnext_node) != SUCCESS) {
+      GELOGE(FAILED, "Set max shape of %s failed.", getnext_node->GetName().c_str());
+      return FAILED;
+    }
+    all_data_nodes.emplace_back(getnext_node);
+  }
 
   all_data_nodes_.swap(all_data_nodes);
   return SUCCESS;
 }
 
+Status MultiBatchClonePass::LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index) {
+  GELOGD("Start link edge for %s, which is the %zu input of %s.", getnext_node->GetName().c_str(),
+         case_input_index, case_node_->GetName().c_str());
+  for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++case_input_index) {
+    if (GraphUtils::AddEdge(getnext_node->GetOutDataAnchor(out_index),
+                            case_node_->GetInDataAnchor(case_input_index)) != GRAPH_SUCCESS) {
+      GELOGE(FAILED, "Failed to add edge between output %zu of Data:%s and input %zu of Case:%s", out_index,
+             getnext_node->GetName().c_str(), case_input_index, case_node_->GetName().c_str());
+      return FAILED;
+    }
+  }
+  if (getnext_sink_dynamic_dims_) {
+    GE_CHK_STATUS_RET(LinkGetNextToGetDynamicDims(getnext_node, shape_node_), "Failed to add link for %s.",
+                      shape_node_->GetName().c_str());
+  }
+  return SUCCESS;
+}
+
 ///
 /// @ingroup ge
 /// @brief Create Const node for root graph.
@@ -378,7 +620,11 @@ Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) {
 Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) {
   // Const --> Case
   std::vector<NodePtr> all_const_nodes;
-  const size_t arg_index = kCaseArgIndex + all_data_nodes_.size();
+  size_t arg_index = kCaseArgIndex + all_data_nodes_.size();
+  if (data_count_from_getnext_ != 0) {
+    arg_index = arg_index + data_count_from_getnext_ - kNumOfGetnextNode;
+  }
+
   for (size_t i = 0; i < all_const_nodes_.size(); ++i) {
     const auto &node = all_const_nodes_[i];
     const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
@@ -395,15 +641,33 @@ Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) {
     const NodePtr &data = graph->AddNode(op_desc);
     GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
     if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
-      GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s",
-             data->GetName().c_str(), case_node_->GetName().c_str());
+      GELOGE(FAILED, "Failed to add edge between Const:%s and Case:%s", data->GetName().c_str(),
+             case_node_->GetName().c_str());
       return FAILED;
     }
     all_const_nodes.emplace_back(data);
   }
+  ChangeConstToData();
+  all_const_nodes_.swap(all_const_nodes);
+  return SUCCESS;
+}
 
+void MultiBatchClonePass::ChangeConstToData() {
   size_t data_index = all_data_nodes_.size();
+  if (data_count_from_getnext_ != 0) {
+    data_index = data_index + data_count_from_getnext_ - kNumOfGetnextNode;
+  }
   for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) {  // Trans subgraph Const to Data.
+    auto &const_node = all_const_nodes_[i];
+    if (out_control_nodes_.find(const_node) != out_control_nodes_.end()) {
+      // A Const that anchors a GetNext control edge must keep its type; skip only this node.
+      GELOGD("No need to change %s to Data type.", const_node->GetName().c_str());
+      continue;
+    }
     const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc();
     op_desc->SetType(DATA);
     (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS);  // Delete weight.
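Taken together, CreateInputNode and CreateConstNode above assume a fixed layout for the Case
node's data inputs: slot 0 carries the gear index produced by MapIndex (hence kCaseArgIndex = 1);
the following slots carry one entry per Data output, with a single GetNext node contributing
data_count_from_getnext_ entries instead of one; the demoted Consts come last. A compact sketch
of that index arithmetic, mirroring the constants in this pass (the helper name is illustrative
only, not part of the change):

#include <cstddef>

// Mirrors the Case-input layout assumed above: input 0 is the MapIndex result.
constexpr size_t kCaseArgIndexSketch = 1;
constexpr size_t kNumOfGetnextNodeSketch = 1;  // at most one GetNext node per graph

static size_t FirstConstArgIndex(size_t data_node_count, size_t data_count_from_getnext) {
  size_t arg_index = kCaseArgIndexSketch + data_node_count;
  if (data_count_from_getnext != 0) {
    // The single GetNext entry in all_data_nodes_ expands into N data inputs.
    arg_index += data_count_from_getnext - kNumOfGetnextNodeSketch;
  }
  return arg_index;
}

// Example: three entries in all_data_nodes_, one being a GetNext with two data
// outputs -> FirstConstArgIndex(3, 2) == 1 + 3 + (2 - 1) == 5.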
@@ -413,9 +677,6 @@ Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); (void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1); } - - all_const_nodes_.swap(all_const_nodes); - return SUCCESS; } /// @@ -461,7 +722,8 @@ Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) { } } } - + GE_CHK_STATUS_RET(LinkGetDynamicDimsToNetOutput(node), "Failed to add edge between %s to netoutput: %s.", + shape_node_->GetName().c_str(), output->GetName().c_str()); all_output_nodes_.clear(); all_output_nodes_.emplace_back(node); return SUCCESS; @@ -473,34 +735,69 @@ Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) { /// @param [in] const NodePtr &data: data in Root/Case graph. /// @return 0: SUCCESS / others: FAILED /// -Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { - auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); - auto data_name = data->GetName(); +Status MultiBatchClonePass::SetMaxShape(const NodePtr &data) { + GELOGD("Start set max shape for %s.", data->GetName().c_str()); + if (!IsGetNextType(data)) { + if (SetMaxShapeToData(data, kDataOutIndex) != SUCCESS) { + GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str()); + return PARAM_INVALID; + } + } else { + for (size_t out_anchor_index = 0; out_anchor_index < data_count_from_getnext_; ++out_anchor_index) { + if (SetMaxShapeToData(data, out_anchor_index) != SUCCESS) { + GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str()); + return PARAM_INVALID; + } + } + } + return SUCCESS; +} + +Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index) { + GELOGD("Start update max shape of %s, %zu output.", node->GetName().c_str(), out_anchor_index); + auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); + string data_name = node->GetName(); + if (IsGetNextType(node)) { + data_name.append("_").append(std::to_string(out_anchor_index)); + } + GELOGD("Update max shape of %s, shape dims is %s.", data_name.c_str(), + formats::JoinToString(data_shape.GetDims()).c_str()); const auto &dims = data_shape.GetDims(); - if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { - return SUCCESS; + if (!IsGetNextType(node)) { + if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { + GELOGD("No need to do anything for static data."); + return SUCCESS; + } + } else { + if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { + if (getnext_sink_dynamic_dims_) { + // need to update shape of Shape_node when getnext node has dynamic data + GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node"); + } + return SUCCESS; + } } - (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); + (void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); - GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); + GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex)); std::vector input_dims_str; for (size_t i = 0; i < batch_shapes_.size(); ++i) { auto shape = data_shape; auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); if (ret != SUCCESS) { - GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not 
match", data->GetName().c_str()); + GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", node->GetName().c_str()); return ret; } tensor.SetShape(shape); int64_t tensor_size = 0; (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + - TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + + TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + node->GetName() + ":" + std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + formats::JoinToString(tensor.GetShape().GetDims()); input_dims_str.emplace_back(input_str); } - (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); + (void)AttrUtils::SetListStr(node->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); size_t max_shape_index = 0; int64_t max_size = 0; @@ -519,18 +816,72 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { max_shape_index = i; } } + return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), node, data_shape, out_anchor_index); +} - return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), data, data_shape); +/// +/// @ingroup ge +/// @brief Set max shape to Data/GetNext node in root graph. +/// @param [in] const std::vector &shapes: dims of shape. +/// @param [in] const NodePtr &data: data in Root/Case graph. +/// @param [in] GeShape &data_shape: dims of data node. +/// @param [in] size_t out_anchor_index: out anchor index of data node. +/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::SetShapeToData(const std::vector &shapes, const NodePtr &data, GeShape &data_shape, + size_t out_anchor_index) { + GELOGD("Start set shape to %zu out of %s.", out_anchor_index, data->GetName().c_str()); + if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to calculate the batched shape for data node %s, the shapes may not match", + data->GetName().c_str()); + return INTERNAL_ERROR; + } + + if (NodeUtils::UpdateOutputShape(*data, out_anchor_index, data_shape) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); + return INTERNAL_ERROR; + } + if (!IsGetNextType(data)) { + if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); + return INTERNAL_ERROR; + } + } else { + if (getnext_sink_dynamic_dims_) { + // need to update shape of Shape_node when getnext_sink_dynamic + GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(data, out_anchor_index), "Failed to update shape of shape node"); + } + } + + GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(), + formats::ShapeToString(data_shape).c_str()); + return SUCCESS; +} + +Status MultiBatchClonePass::UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index) { + GELOGD("Start update output shape of shape node insert by adapter, which is the %zu out of %s.", out_anchor_index, + node->GetName().c_str()); + auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); + size_t shape_index = out_anchor_index + (node->GetAllOutDataAnchors().size() / kDivisionConst); + GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(shape_index); + std::vector output_dims = 
{static_cast(data_shape.GetDims().size())}; + GeShape output_shape(output_dims); + output_desc.SetShape(output_shape); + if (node->GetOpDesc()->UpdateOutputDesc(shape_index, output_desc) != SUCCESS) { + GELOGE(FAILED, "Update output desc fail."); + return FAILED; + } + return SUCCESS; } /// /// @ingroup ge /// @brief Update Data node in Subgraph. /// @param [in] const NodePtr &data: data in Subgraph. -/// @param [in] size_t index: The batch index. +/// @param [in] size_t batch_index: The batch index. /// @return 0: SUCCESS / others: FAILED /// -Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) { +Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t batch_index) { int node_index = -1; if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) { GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str()); @@ -545,6 +896,8 @@ Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); const auto &dims = data_shape.GetDims(); + GELOGD("Start update shape of %s , batch index is %zu, dims is %s.", data->GetName().c_str(), batch_index, + formats::JoinToString(dims).c_str()); if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { return SUCCESS; } @@ -559,35 +912,77 @@ Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index } auto parent_name = data_name.substr(0, pos); - return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(index), data, data_shape); + return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(batch_index), data, data_shape, kDataOutIndex); } -/// -/// @ingroup ge -/// @brief Set max shape to Data node in root graph. -/// @param [in] const std::vector &shapes: dims of shape. -/// @param [in] const NodePtr &data: data in Root/Case graph. -/// @param [in] GeShape &data_shape: dims of data node. 
-/// @return 0: SUCCESS / others: FAILED -/// -Status MultiBatchClonePass::SetShapeToData(const vector &shapes, const NodePtr &data, GeShape &data_shape) { - // must not be error, the calc result has been checked in function InsertSwitchNForData - if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) { - return INTERNAL_ERROR; +Status MultiBatchClonePass::CreateOriGraph(const ComputeGraphPtr &graph) { + if (data_count_from_getnext_ == 0) { + GELOGD("No need to change original graph without getnext node."); + return SUCCESS; } - - if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); - return INTERNAL_ERROR; + GELOGD("Start change original graph: %s when exit getnext node.", graph->GetName().c_str()); + size_t data_index = all_data_nodes_.size() - kNumOfGetnextNode; + for (const auto &node : graph->GetDirectNode()) { + if (IsGetNextType(node)) { + for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++data_index) { + auto out_data_anchor = node->GetOutDataAnchor(out_index); + GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); + NodePtr data_node = CreateDataNode(graph, out_data_anchor, data_index); + GE_IF_BOOL_EXEC(data_node == nullptr, GELOGE(INTERNAL_ERROR, "Create %zu data node failed.", + out_data_anchor->GetIdx()); return INTERNAL_ERROR); + for (auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(in_anchor == nullptr, continue); + NodePtr dst_node = in_anchor->GetOwnerNode(); + if (GraphUtils::RemoveEdge(out_data_anchor, in_anchor) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to remove edge between %s to %s", node->GetName().c_str(), + dst_node->GetName().c_str()); + return INTERNAL_ERROR; + } + if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(in_anchor->GetIdx())) != + GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to add edge between %s to %s", data_node->GetName().c_str(), + dst_node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + } + if (graph->RemoveNode(node) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove node %s failed!", node->GetName().c_str()); + return GRAPH_FAILED; + } + break; + } } + return SUCCESS; +} - if (NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); - return INTERNAL_ERROR; +NodePtr MultiBatchClonePass::CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, + size_t data_index) { + size_t out_anchor_index = out_data_anchor->GetIdx(); + std::string node_name = out_data_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor_index); + OpDescPtr op_desc = MakeShared(node_name, DATA); + if (op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create data node failed."); + return nullptr; } + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); - GELOGI("Update %s input/output shape to %s", data->GetName().c_str(), formats::ShapeToString(data_shape).c_str()); - return SUCCESS; + OpDescPtr getnext_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); + if (getnext_op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Op desc of %s is nullptr.", out_data_anchor->GetOwnerNode()->GetName().c_str()); + return nullptr; + } + if (op_desc->AddInputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add %s input desc 
failed.", op_desc->GetName().c_str()); + return nullptr; + } + if (op_desc->AddOutputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add %s output desc failed.", op_desc->GetName().c_str()); + return nullptr; + } + NodePtr data_node = graph->AddNode(op_desc); + GELOGD("Success create %s node.", data_node->GetName().c_str()); + return data_node; } /// @@ -598,17 +993,14 @@ Status MultiBatchClonePass::SetShapeToData(const vector &shapes, const /// @return 0: SUCCESS / others: FAILED /// Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) { + GELOGD("Start create subgraphs for %s.", graph->GetName().c_str()); const auto &op_desc = case_node_->GetOpDesc(); for (size_t i = 0; i < batch_shapes_.size(); ++i) { std::vector input_nodes; std::vector output_nodes; const std::string postfix = kMultiBatchNodePostfix + std::to_string(i); ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes); - if (subgraph == nullptr) { - GELOGE(FAILED, "Create multi-batch case node failed"); - return FAILED; - } - + GE_IF_BOOL_EXEC(subgraph == nullptr, GELOGE(FAILED, "Create multi-batch case node failed"); return FAILED); subgraph->SetName("Batch_" + std::to_string(i)); subgraph->SetParentNode(case_node_); subgraph->SetParentGraph(graph); @@ -621,6 +1013,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const op_desc->AddSubgraphName(key_name); op_desc->SetSubgraphInstanceName(i, subgraph->GetName()); + GELOGD("The %s has %zu input, %zu output.", subgraph->GetName().c_str(), input_nodes.size(), output_nodes.size()); for (const auto &data : input_nodes) { GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str()); } @@ -666,6 +1059,7 @@ Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { /// @return 0: SUCCESS / others: FAILED /// Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { + GELOGD("Start prune direct output."); const auto &func_desc = case_node_->GetOpDesc(); uint32_t unused_num = 0; uint32_t output_num = func_desc->GetOutputsSize(); @@ -710,6 +1104,7 @@ Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { /// Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) { if (unused_num == 0) { + GELOGD("No need to update output tensor."); return SUCCESS; } diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h index ee137b5a..66e92892 100755 --- a/ge/graph/passes/multi_batch_clone_pass.h +++ b/ge/graph/passes/multi_batch_clone_pass.h @@ -36,6 +36,7 @@ class MultiBatchClonePass : public GraphPass { /// @return 0: SUCCESS / others: FAILED /// Status CollectIoNodes(const ComputeGraphPtr &graph); + Status InitParamsOfGetNext(const NodePtr &node); /// /// @ingroup ge @@ -49,10 +50,12 @@ class MultiBatchClonePass : public GraphPass { /// @ingroup ge /// @brief Create index data node for root graph. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. - /// @param [in] NodePtr node: index data node. + /// @param [in] NodePtr shape_node: index data node, DATA or GETDYNAMICDIMS type. 
/// @return 0: SUCCESS / others: FAILED /// - Status CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node); + Status CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node); + + Status CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node); /// /// @ingroup ge @@ -70,6 +73,9 @@ class MultiBatchClonePass : public GraphPass { /// @return 0: SUCCESS / others: FAILED /// Status CreateIndexNode(const ComputeGraphPtr &graph); + Status AddAttrForGetDynamicDims(const NodePtr &shape_node); + Status LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node); + Status LinkGetDynamicDimsToNetOutput(const NodePtr &output_node); /// /// @ingroup ge @@ -78,39 +84,54 @@ class MultiBatchClonePass : public GraphPass { /// @return 0: SUCCESS / others: FAILED /// Status CreateInputNode(const ComputeGraphPtr &graph); + Status LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index); /// /// @ingroup ge - /// @brief Create Const node for root graph. - /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. + /// @brief Set max shape to Data node in root graph. + /// @param [in] const NodePtr &data: data in Root/Case graph. /// @return 0: SUCCESS / others: FAILED /// - Status CreateConstNode(const ComputeGraphPtr &graph); + Status SetMaxShape(const NodePtr &data); + Status SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index); + /// + /// @ingroup ge + /// @brief Set max shape to Data/GetNext node in root graph. + /// @param [in] const std::vector &shapes: dims of shape. + /// @param [in] const NodePtr &data: data in Root/Case graph. + /// @param [in] GeShape &data_shape: dims of data node. + /// @param [in] size_t out_anchor_index: out anchor index of data node. + /// @return 0: SUCCESS / others: FAILED + /// + Status SetShapeToData(const std::vector &shapes, const NodePtr &data, GeShape &data_shape, + size_t out_anchor_index); + Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); /// /// @ingroup ge - /// @brief Create output node for root graph. + /// @brief Create Const node for root graph. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. /// @return 0: SUCCESS / others: FAILED /// - Status CreateOutputNode(const ComputeGraphPtr &graph); + Status CreateConstNode(const ComputeGraphPtr &graph); + void ChangeConstToData(); /// /// @ingroup ge - /// @brief Set max shape to Data node in root graph. - /// @param [in] const NodePtr &data: data in Root/Case graph. + /// @brief Create output node for root graph. + /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. /// @return 0: SUCCESS / others: FAILED /// - Status SetMaxShapeToData(const NodePtr &data); + Status CreateOutputNode(const ComputeGraphPtr &graph); /// /// @ingroup ge /// @brief Update Data node in Subgraph. /// @param [in] const NodePtr &data: data in Subgraph. - /// @param [in] size_t index: The batch index. + /// @param [in] size_t batch_index: The batch index. /// @return 0: SUCCESS / others: FAILED /// - Status UpdateSubgraphData(const NodePtr &data, size_t index); + Status UpdateSubgraphData(const NodePtr &data, size_t batch_index); /// /// @ingroup ge @@ -122,13 +143,12 @@ class MultiBatchClonePass : public GraphPass { /// /// @ingroup ge - /// @brief Set max shape to Data node in root graph. - /// @param [in] const std::vector &shapes: dims of shape. - /// @param [in] const NodePtr &data: data in Root/Case graph. - /// @param [in] GeShape &data_shape: dims of data node. 
+ /// @brief Create nodes for root graph. + /// @param [in] const ComputeGraphPtr &graph: Original graph. /// @return 0: SUCCESS / others: FAILED /// - Status SetShapeToData(const std::vector &shapes, const NodePtr &data, GeShape &data_shape); + Status CreateOriGraph(const ComputeGraphPtr &graph); + NodePtr CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, size_t data_index); /// /// @ingroup ge @@ -168,6 +188,10 @@ class MultiBatchClonePass : public GraphPass { std::map>> data_to_dynamic_info_; NodePtr case_node_; + size_t data_count_from_getnext_ = 0; + bool getnext_sink_dynamic_dims_ = false; + NodePtr shape_node_; + std::set out_control_nodes_; }; } // namespace ge #endif // GE_GRAPH_PASSES_MULTI_BATCH_CLONE_PASS_H_ diff --git a/ge/graph/passes/unused_args_clean_pass.cc b/ge/graph/passes/unused_args_clean_pass.cc index 83fd0438..ec66b129 100755 --- a/ge/graph/passes/unused_args_clean_pass.cc +++ b/ge/graph/passes/unused_args_clean_pass.cc @@ -204,6 +204,10 @@ Status UnusedArgsCleanPass::RemoveInputTensor(const mapGetName().c_str(), func_node->GetName().c_str()); + if (out_node->GetInDataNodes().size() == 0 && out_node->GetOutAllNodes().size() == 0) { + GE_CHK_GRAPH_STATUS_RET(out_node->GetOwnerComputeGraph()->RemoveNode(out_node), "Remove node failed: %s", + out_node->GetName().c_str()); + } return SUCCESS; } } // namespace ge \ No newline at end of file diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index c8880b2e..5506435e 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1692,13 +1692,11 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { } Status ProcessMultiBatch(ComputeGraphPtr &graph) { - if (GetLocalOmgContext().dynamic_node_type.empty()) { - const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); - if (multi_batch_with_switchn == nullptr) { - PassManager pass_manager; - GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); - return pass_manager.Run(graph); - } + const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); + if (multi_batch_with_switchn == nullptr) { + PassManager pass_manager; + GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); + return pass_manager.Run(graph); } if (!GetLocalOmgContext().need_multi_batch) { GELOGI("No need to process_multi for no_train graph."); diff --git a/ge/graph/preprocess/multi_batch_options.cc b/ge/graph/preprocess/multi_batch_options.cc index c26b08bc..aba2b88d 100644 --- a/ge/graph/preprocess/multi_batch_options.cc +++ b/ge/graph/preprocess/multi_batch_options.cc @@ -99,9 +99,8 @@ Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector &data_n } GELOGI("Data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(), getnext_nosink_nodes.size(), getnext_sink_nodes.size()); - GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrDataNodes, data_nodes), GELOGW("Set data nodes attr failed.");) - GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes), - GELOGW("Set getnext nosink nodes attr failed.");) + GetLocalOmgContext().data_nodes = data_nodes; + GetLocalOmgContext().getnext_nosink_nodes = getnext_nosink_nodes; return SUCCESS; } diff --git a/inc/framework/omg/omg_inner_types.h b/inc/framework/omg/omg_inner_types.h index dab79053..1049b6b5 100644 --- 
a/inc/framework/omg/omg_inner_types.h +++ b/inc/framework/omg/omg_inner_types.h @@ -26,6 +26,7 @@ #include #include "framework/common/fmk_error_codes.h" #include "register/register_fmk_types.h" +#include "graph/node.h" using domi::DOMI_TENSOR_ND; using domi::DOMI_TENSOR_RESERVED; @@ -120,6 +121,8 @@ struct OmgContext { std::vector> user_real_input_dims; std::vector cur_dynamic_dims; bool need_multi_batch = false; + std::vector data_nodes; + std::vector getnext_nosink_nodes; }; } // namespace ge diff --git a/metadef b/metadef index 30cf97ba..fe37bc34 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 30cf97ba0c9a70ade0d9df92695c9dcd671316f6 +Subproject commit fe37bc343ea52c76d35e9e9ec83cea0151bfa900 diff --git a/parser b/parser index e338bc22..336cd310 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit e338bc2200bed9f11f6e665e6ad37a3a97906354 +Subproject commit 336cd3107253d3fe41cfb9fec2db62b5f3d8a33b diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index dcf389c0..db725dfb 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -627,6 +627,7 @@ set(PASS_TEST_FILES "graph/passes/net_output_pass_unittest.cc" "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" + "graph/passes/multi_batch_clone_pass_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/load/davinci_model_unittest.cc b/tests/ut/ge/graph/load/davinci_model_unittest.cc index a9efab3d..9e51585b 100644 --- a/tests/ut/ge/graph/load/davinci_model_unittest.cc +++ b/tests/ut/ge/graph/load/davinci_model_unittest.cc @@ -32,6 +32,18 @@ class UtestDavinciModel : public testing::Test { void SetUp() {} void TearDown() {} + public: + NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); + auto op_desc = std::make_shared(name, type); + for (auto i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(test_desc); + } + for (auto i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(test_desc); + } + return graph->AddNode(op_desc); + } }; TEST_F(UtestDavinciModel, init_success) { @@ -324,5 +336,94 @@ TEST_F(UtestDavinciModel, SyncVarData_test) { EXPECT_NE(model.SyncVarData(), SUCCESS); } +TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ1) { + DavinciModel model(0, nullptr); + model.ge_model_ = make_shared(); + ComputeGraphPtr graph = make_shared("default"); + + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + OpDescPtr op_output = CreateOpDesc("output_ascend_mbatch_batch_1", NETOUTPUT); + op_output->AddInputDesc(tensor); + op_output->SetInputOffset({1024}); + NodePtr node_output = graph->AddNode(op_output); + EXPECT_EQ(model.InitRealSizeAndShapeInfo(graph, node_output), SUCCESS); +} + +TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ2) { + DavinciModel model(0, nullptr); + ComputeGraphPtr graph = std::make_shared("test_graph"); + + OpDescPtr data1 = CreateOpDesc("data1", DATA); + GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->AddInputDesc(shape_desc); + data1->AddOutputDesc(shape_desc); + NodePtr data1_node = graph->AddNode(data1); + + OpDescPtr case_node = CreateOpDesc("case1", CASE); + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + case_node->AddInputDesc(tensor); + case_node->AddOutputDesc(tensor); + NodePtr case1_node = graph->AddNode(case_node); + + OpDescPtr output = CreateOpDesc("output1", NETOUTPUT); + output->AddInputDesc(tensor); + 
output->SetSrcName( { "case1" } ); + output->SetSrcIndex( { 0 } ); + NodePtr output_node = graph->AddNode(output); + + GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), case1_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(case1_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + + (void)AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, "1;2;4;8"); + (void)AttrUtils::SetBool(case_node, ATTR_INSERT_BY_MBATCH, true); + + model.is_getnext_sink_dynamic_ = false; + model.is_online_infer_dynamic_ = true; + auto ret = model.InitRealSizeAndShapeInfo(graph, output_node); + // GetGearAndRealOutShapeInfo without ATTR_NAME_DYNAMIC_OUTPUT_DIMS + EXPECT_EQ(ret, SUCCESS); + vector dynamic_output_dims = {"0,0,1,1,0,2,2,0,4,3,0,8"}; + (void)AttrUtils::SetListStr(output_node->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims); + ret = model.InitRealSizeAndShapeInfo(graph, output_node); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ3) { + DavinciModel model(0, nullptr); + ComputeGraphPtr graph = std::make_shared("test_graph"); + + OpDescPtr data1 = CreateOpDesc("data1", DATA); + GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->AddInputDesc(shape_desc); + data1->AddOutputDesc(shape_desc); + NodePtr data1_node = graph->AddNode(data1); + + OpDescPtr shape_node = CreateOpDesc("ascend_mbatch_get_dynamic_dims_node", GETDYNAMICDIMS); + GeTensorDesc in_tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + GeTensorDesc out_tensor(GeShape({4,3}), FORMAT_NCHW, DT_FLOAT); + shape_node->AddInputDesc(in_tensor); + shape_node->AddOutputDesc(out_tensor); + NodePtr get_dynamic_dims_node = graph->AddNode(shape_node); + + OpDescPtr output = CreateOpDesc("output1", NETOUTPUT); + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + output->AddInputDesc(tensor); + output->SetSrcName( { "data1", "ascend_mbatch_get_dynamic_dims_node" } ); + output->SetSrcIndex( { 0, 1 } ); + NodePtr output_node = graph->AddNode(output); + GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(get_dynamic_dims_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(1)); + + (void)AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, "1,3;;4,3;,3"); + + model.is_getnext_sink_dynamic_ = true; + model.is_online_infer_dynamic_ = false; + auto ret = model.InitRealSizeAndShapeInfo(graph, output_node); + EXPECT_EQ(ret, SUCCESS); + model.runtime_param_.mem_base = (uint8_t *)0x08000000; + model.runtime_param_.mem_size = 4; + ret = model.InitRealSizeAndShapeInfo(graph, output_node); + EXPECT_EQ(ret, SUCCESS); +} } // namespace ge diff --git a/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc new file mode 100644 index 00000000..b1cd6d4d --- /dev/null +++ b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc @@ -0,0 +1,247 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/multi_batch_clone_pass.h" + +#include +#include +#include + +#include "inc/pass_manager.h" +#include "graph/utils/tensor_utils.h" +#include "graph/common/local_context.h" +#include "graph/passes/multi_batch_pass.h" +#include "graph/preprocess/multi_batch_copy_graph.h" +#include "graph/preprocess/insert_op/util_insert_aipp_op.h" +#include "framework/omg/omg_inner_types.h" +#include "register/op_registry.h" + + +namespace ge{ +class UtestMultiBatchClonePass : public testing::Test { +protected: + void SetUp() { + SetLocalOmgContext(domi::GetContext()); + GetLocalOmgContext().dynamic_image_size.clear(); + GetLocalOmgContext().dynamic_batch_size.clear(); + } + void TearDown() { + GetLocalOmgContext().dynamic_image_size.clear(); + GetLocalOmgContext().dynamic_batch_size.clear(); + GetLocalOmgContext().dynamic_node_type.clear(); + } + +public: + NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); + auto op_desc = std::make_shared(name, type); + for (auto i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(test_desc); + } + for (auto i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(test_desc); + } + return graph->AddNode(op_desc); + } + + NodePtr MakeConstNode(const ComputeGraphPtr &graph) { + static uint32_t index = 0; + GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); + auto op_desc = std::make_shared("dynamic_const_" + std::to_string(index++), "Const"); + op_desc->AddOutputDesc(test_desc); + return graph->AddNode(op_desc); + } + + void make_original_graph(const ComputeGraphPtr &graph) { + auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); + { + auto data1 = MakeNode(graph, 1, 1, "data", "Data"); + GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); + GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector{-1,3,224,224})}; + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); + } + + auto bn_conv1 = MakeNode(graph, 4, 1, "bn_conv1", "BNInference"); + { + GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(2)); + auto const3= MakeConstNode(graph); + GraphUtils::AddEdge(const3->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(3)); + } + + auto scale_conv1 = MakeNode(graph, 4, 1, "scale1", "Scale"); + { + GraphUtils::AddEdge(bn_conv1->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(2)); + } + + auto output_node = 
MakeNode(graph, 1, 0, "output1", "NetOutput"); + GraphUtils::AddEdge(scale_conv1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + } + + void GraphWithJustData(const ComputeGraphPtr &graph) { + auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); + { + auto data1 = MakeNode(graph, 1, 1, "data", "Data"); + GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); + GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector{-1,3,224,224})}; + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); + } + + auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); + GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + } + + void GraphWithGetNextNosink(const ComputeGraphPtr &graph) { + auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); + { + auto data1 = MakeNode(graph, 1, 1, "IteratorGetNext_data", "Data"); + GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); + GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector{-1,3,224,224})}; + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); + } + + auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); + GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + } + + // getnext has one data and has one out of shape + void GraphWithGetNextSink(const ComputeGraphPtr &graph) { + auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); + { + auto data1 = MakeNode(graph, 1, 2, "data", "IteratorV2"); + GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(1, shape_desc); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); + GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector{-1,3,224,224})}; + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); + auto identity = MakeNode(graph, 1, 0, "identity", "Identity"); + GraphUtils::AddEdge(data1->GetOutDataAnchor(1), identity->GetInDataAnchor(0)); + auto const1 = MakeConstNode(graph); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); + auto const2 = MakeConstNode(graph); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); + } + + auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput"); + GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), 
output_node->GetInDataAnchor(0));
+  }
+};
+
+// graph is nullptr
+TEST_F(UtestMultiBatchClonePass, graph_nullptr) {
+  PassManager pass_manager;
+  pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass);
+  ComputeGraphPtr graph;
+  EXPECT_EQ(pass_manager.Run(graph), PARAM_INVALID);
+}
+
+// graph with subgraph
+TEST_F(UtestMultiBatchClonePass, graph_with_subgraph) {
+  PassManager pass_manager;
+  pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass);
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  make_original_graph(graph);
+  EXPECT_EQ(pass_manager.Run(graph), SUCCESS);
+
+  ComputeGraphPtr owner = std::make_shared<ComputeGraph>("test_owner");
+  auto func_node = MakeNode(owner, 3, 1, "test_if", "If");
+  graph->SetParentNode(func_node);
+  graph->SetParentGraph(owner);
+  EXPECT_EQ(pass_manager.Run(graph), SUCCESS);
+}
+
+// graph that does not need multi-batch processing (need_multi_batch is false)
+TEST_F(UtestMultiBatchClonePass, uncompute_graph) {
+  MultiBatchClonePass multi_batch_clone;
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  make_original_graph(graph);
+  GetLocalOmgContext().need_multi_batch = false;
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+}
+
+// compute_graph with data from DATA
+TEST_F(UtestMultiBatchClonePass, compute_graph_with_data) {
+  MultiBatchClonePass multi_batch_clone;
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  GraphWithJustData(graph);
+  GetLocalOmgContext().need_multi_batch = true;
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+  GetLocalOmgContext().dynamic_node_type = DATA;
+  GetLocalOmgContext().dynamic_dims = "1;2;4;8";
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+  EXPECT_EQ(GetLocalOmgContext().data_nodes.size(), 1);
+}
+
+// compute_graph with data from GetNext_nosink
+TEST_F(UtestMultiBatchClonePass, compute_graph_with_getnext_nosink) {
+  MultiBatchClonePass multi_batch_clone;
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  GraphWithGetNextNosink(graph);
+  GetLocalOmgContext().need_multi_batch = true;
+  GetLocalOmgContext().dynamic_node_type = GETNEXT;
+  GetLocalOmgContext().dynamic_dims = "1;2;4;8";
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+  EXPECT_EQ(GetLocalOmgContext().getnext_nosink_nodes.size(), 1);
+}
+
+// compute_graph with data from GetNext_sink
+TEST_F(UtestMultiBatchClonePass, compute_graph_with_getnext_sink) {
+  MultiBatchClonePass multi_batch_clone;
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  GraphWithGetNextSink(graph);
+  GetLocalOmgContext().need_multi_batch = true;
+  GetLocalOmgContext().dynamic_node_type = GETNEXT;
+  GetLocalOmgContext().dynamic_dims = "1;2;4;8";
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+  EXPECT_EQ(GetLocalOmgContext().getnext_nosink_nodes.size(), 0);
+}
+
+}  // namespace ge
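The tests above drive the pass with dynamic_dims = "1;2;4;8", the semicolon-separated gear list
consumed by InitDynamicParams, and the getnext-sink path additionally flattens
GetLocalOmgContext().user_input_dims into ATTR_GETNEXT_SINK_SHAPE_INFO in AddAttrForGetDynamicDims
as [rank, d0, d1, ...] per input, with a bare 0 marking a const input. A minimal, self-contained
sketch of both layouts, assuming only what this patch shows; ParseDynamicDimsOption and
DecodeSinkShapeInfo are hypothetical helpers for illustration, not GE APIs:

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Gear-list option as set in the tests above: one gear per ';', one dim per ','.
// "1;2;4;8" -> {{1}, {2}, {4}, {8}}; "1,3;2,3" -> {{1, 3}, {2, 3}}.
static std::vector<std::vector<int64_t>> ParseDynamicDimsOption(const std::string &option) {
  std::vector<std::vector<int64_t>> gears;
  std::stringstream gear_stream(option);
  std::string gear;
  while (std::getline(gear_stream, gear, ';')) {
    std::vector<int64_t> dims;
    std::stringstream dim_stream(gear);
    std::string dim;
    while (std::getline(dim_stream, dim, ',')) {
      if (!dim.empty()) {
        dims.emplace_back(std::stoll(dim));
      }
    }
    gears.emplace_back(dims);
  }
  return gears;
}

// Inverse of the flattening done by AddAttrForGetDynamicDims: each input is its
// rank followed by that many dims; a standalone 0 is a const input (empty shape).
static std::vector<std::vector<int64_t>> DecodeSinkShapeInfo(const std::vector<int64_t> &shape_info) {
  std::vector<std::vector<int64_t>> shapes;
  size_t i = 0;
  while (i < shape_info.size()) {
    const int64_t rank = shape_info[i++];
    std::vector<int64_t> dims;
    for (int64_t j = 0; j < rank && i < shape_info.size(); ++j) {
      dims.emplace_back(shape_info[i++]);
    }
    shapes.emplace_back(dims);
  }
  return shapes;
}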