From 0899b516331e46e7bc5e7e84be1588a4eed0976b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=A3=8A?=
Date: Wed, 28 Apr 2021 14:45:31 +0800
Subject: [PATCH 1/2] parse topic type

---
 .../task_info/kernel_ex_task_info.cc          |  9 +++
 .../task_info/kernel_ex_task_info.h           |  1 +
 .../task_info/kernel_task_info.cc             |  9 +++
 .../task_info/kernel_task_info.h              |  1 +
 .../node_executor/aicpu/aicpu_ext_info.cc     | 56 +++++++++++++++++++
 .../node_executor/aicpu/aicpu_ext_info.h      |  5 ++
 .../load/kernel_ex_task_info_unittest.cc      | 54 ++++++++++++++++++
 .../fwkacllib/inc/cce/fwk_adpt_struct.h       |  9 +++
 third_party/fwkacllib/inc/runtime/kernel.h    |  5 ++
 9 files changed, 149 insertions(+)

diff --git a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc
index e2f600b3..e1097149 100644
--- a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc
+++ b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc
@@ -53,6 +53,7 @@ Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDe
                     "Parse kernel ext info failed, kernel_ext_info_size=%zu.", ext_info.size());
   GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed.");
   GELOGD("Update aicpu_task ext_info bit_map execute mode to 1.");
+  topic_type_flag_ = ext_handle->GetTopicTypeFlag();
 
   bool all_shape = false;
   (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape);
@@ -406,6 +407,14 @@ Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const
 
 Status KernelExTaskInfo::Distribute() {
   GELOGI("KernelExTaskInfo Distribute Start.");
+  // Use the fifth and sixth bits of dump_flag_ to indicate the value of topic_type.
+  // xxxxxxxx xxxxxxxx xxxxxxxx xx00xxxx: DEVICE_ONLY
+  // xxxxxxxx xxxxxxxx xxxxxxxx xx01xxxx: DEVICE_FIRST
+  // xxxxxxxx xxxxxxxx xxxxxxxx xx10xxxx: HOST_ONLY
+  // xxxxxxxx xxxxxxxx xxxxxxxx xx11xxxx: HOST_FIRST
+  if (topic_type_flag_ > 0) {
+    dump_flag_ = dump_flag_ | topic_type_flag_;
+  }
   rtError_t rt_ret = rtKernelLaunchEx(kernel_buf_, kernel_buf_size_, dump_flag_, stream_);
   if (rt_ret != RT_ERROR_NONE) {
     REPORT_CALL_ERROR("E19999", "Call rtKernelLaunchEx failed, ret:0x%X",
diff --git a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h
index 71153c31..bcc17168 100644
--- a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h
+++ b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h
@@ -76,6 +76,7 @@ class KernelExTaskInfo : public TaskInfo {
   vector<void *> io_addrs_;
   uint32_t args_offset_ = 0;
   int64_t fixed_addr_offset_ = 0;
+  int32_t topic_type_flag_ = -1;
 };
 }  // namespace ge
 #endif  // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_
diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_task_info.cc
index 82c3e286..56ce988c 100755
--- a/ge/graph/load/model_manager/task_info/kernel_task_info.cc
+++ b/ge/graph/load/model_manager/task_info/kernel_task_info.cc
@@ -431,6 +431,14 @@ Status KernelTaskInfo::Distribute() {
   int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, kBaseInt) : kStrtolFail;
   bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_);
   if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) {
+    if (topic_type_flag_ > 0) {
+      // Use the fifth and sixth bits of dump_flag_ to indicate the value of topic_type.
+      // xxxxxxxx xxxxxxxx xxxxxxxx xx00xxxx: DEVICE_ONLY
+      // xxxxxxxx xxxxxxxx xxxxxxxx xx01xxxx: DEVICE_FIRST
+      // xxxxxxxx xxxxxxxx xxxxxxxx xx10xxxx: HOST_ONLY
+      // xxxxxxxx xxxxxxxx xxxxxxxx xx11xxxx: HOST_FIRST
+      dump_flag_ = dump_flag_ | topic_type_flag_;
+    }
     GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_);
     // blockDim is reserved parameter, set to 1
     rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()),
@@ -1117,6 +1125,7 @@ Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) {
   GELOGD("Update aicpu_task ext_info session_info session_id is %lu", davinci_model_->GetSessionId());
   GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed.");
   GELOGD("Update aicpu_task ext_info bit_map execute mode to 1.");
+  topic_type_flag_ = ext_handle->GetTopicTypeFlag();
 
   bool all_shape = false;
   (void)AttrUtils::GetBool(op_desc_, kAicpuAllshape, all_shape);
diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.h b/ge/graph/load/model_manager/task_info/kernel_task_info.h
index 4156c511..83df8736 100644
--- a/ge/graph/load/model_manager/task_info/kernel_task_info.h
+++ b/ge/graph/load/model_manager/task_info/kernel_task_info.h
@@ -169,6 +169,7 @@ class KernelTaskInfo : public TaskInfo {
   uint16_t io_addr_offset_ = 0;
   bool l2_buffer_on_ = false;
   bool call_save_dump_ = false;
+  int32_t topic_type_flag_ = -1;
 
   // aicpu ext_info device mem
   void *aicpu_ext_info_addr_ = nullptr;
diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc
index b6c48157..17d5cf8b 100644
--- a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc
+++ b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc
@@ -24,6 +24,12 @@ namespace hybrid {
 namespace {
 // if dim count is not reach kMaxShapeDims(8), use INT64_MIN to mark dim end.
 constexpr int64_t kDimEndFlag = INT64_MIN;
+const std::map<int32_t, int32_t> kTopicTypeToRtsFlagMap {
+    {static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_ONLY), 0},
+    {static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_FIRST), RT_KERNEL_DEVICE_FIRST},
+    {static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_ONLY), RT_KERNEL_HOST_ONLY},
+    {static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_FIRST), RT_KERNEL_HOST_FIRST}
+};
 }
 
 Status AicpuExtInfoHandler::Parse(const std::string &ext_info) {
@@ -72,6 +78,9 @@ Status AicpuExtInfoHandler::Parse(const std::string &ext_info) {
       case aicpu::FWKAdapter::FWK_ADPT_EXT_UPDATE_ADDR:
         GE_CHK_STATUS_RET(ParseExtUpdateAddr(aicpu_ext_info), "[Parse][ExtUpdateAddr] failed.");
         break;
+      case aicpu::FWKAdapter::FWK_ADPT_EXT_TOPIC_TYPE:
+        GE_CHK_STATUS_RET(ParseExtTopicType(aicpu_ext_info), "[Parse][ExtTopicType] failed.");
+        break;
       default:
         GELOGD("Node[%s] ignore infoType=%d, infoLen=%u.",
                node_name_.c_str(), aicpu_ext_info->infoType, aicpu_ext_info->infoLen);
@@ -207,6 +216,44 @@ Status AicpuExtInfoHandler::ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info) {
   return SUCCESS;
 }
 
+Status AicpuExtInfoHandler::ParseExtTopicType(AicpuExtInfo *aicpu_ext_info) {
+  if (aicpu_ext_info->infoLen != sizeof(int32_t)) {
+    REPORT_INNER_ERROR("E19999",
+                       "Node[%s] parse topic_type info failed as infoLen must be %zu but %u.",
+                       node_name_.c_str(), sizeof(int32_t), aicpu_ext_info->infoLen);
+    GELOGE(ACL_ERROR_GE_PARAM_INVALID,
+           "[Check][DataLen]Node[%s] parse topic_type info failed as infoLen must be %zu but %u.",
+           node_name_.c_str(), sizeof(int32_t), aicpu_ext_info->infoLen);
+    return ACL_ERROR_GE_PARAM_INVALID;
+  }
+  GE_CHECK_NOTNULL(aicpu_ext_info->infoMsg);
+  auto type = *reinterpret_cast<const int32_t *>(aicpu_ext_info->infoMsg);
+
+  topic_type_flag_ = TopicTypeToRtsFlag(type);
+  if (topic_type_flag_ == -1) {
+    REPORT_INNER_ERROR("E19999", "Node[%s] parse ext topic type failed as need %d %d %d %d but %d.",
+                       node_name_.c_str(),
+                       aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_ONLY,
+                       aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_FIRST,
+                       aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_ONLY,
+                       aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_FIRST,
+                       type);
+    GELOGE(ACL_ERROR_GE_PARAM_INVALID,
+           "[Check][Type]Node[%s] parse ext topic type failed as need %d %d %d %d but %d.",
+           node_name_.c_str(),
+           aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_ONLY,
+           aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_FIRST,
+           aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_ONLY,
+           aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_FIRST,
+           type);
+    return ACL_ERROR_GE_PARAM_INVALID;
+  }
+
+  GELOGI("Node[%s] parse ext topic type info success infoLen=%u, topic_type=%d, rts_flag=%d.",
+         node_name_.c_str(), aicpu_ext_info->infoLen, type, topic_type_flag_);
+  return SUCCESS;
+}
+
 Status AicpuExtInfoHandler::UpdateExecuteMode(bool flag) {
   if (bit_map_ == nullptr) {
     GELOGD("There is no bit_map in ext_info, no need update.");
@@ -341,5 +388,14 @@ void AicpuExtInfoHandler::GetShapeAndType(const AicpuShapeAndType *shape_and_typ
   data_type = static_cast<DataType>(shape_and_type->type);
   shape = GeShape(dims);
 }
+
+int32_t AicpuExtInfoHandler::TopicTypeToRtsFlag(int32_t topic_type) {
+  auto it = kTopicTypeToRtsFlagMap.find(topic_type);
+  if (it != kTopicTypeToRtsFlagMap.end()) {
+    return it->second;
+  }
+
+  return -1;
+}
 }  // namespace hybrid
 }  // namespace ge
diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h
index 01092204..46fb7c05 100644
--- a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h
+++ b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h
@@ -62,6 +62,7 @@ class AicpuExtInfoHandler {
   Status GetOutputShapeAndType(uint32_t output_index, GeShape &shape, DataType &data_type);
 
   bool IsNeedRefreshIOAddr();
+  int32_t GetTopicTypeFlag() const { return topic_type_flag_; }
 
  private:
@@ -71,6 +72,7 @@ class AicpuExtInfoHandler {
   Status ParseExtSessionInfo(AicpuExtInfo *aicpu_ext_info);
   Status ParseExtBitMap(AicpuExtInfo *aicpu_ext_info);
   Status ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info);
+  Status ParseExtTopicType(AicpuExtInfo *aicpu_ext_info);
 
   static Status UpdateShapeAndType(const GeShape &shape,
                                    DataType data_type,
@@ -81,6 +83,8 @@ class AicpuExtInfoHandler {
                                  DataType &data_type);
 
  private:
+  int32_t TopicTypeToRtsFlag(int32_t topic_type);
+
   const std::string node_name_;
   const uint32_t input_num_;
   const uint32_t output_num_;
@@ -88,6 +92,7 @@ class AicpuExtInfoHandler {
   AicpuSessionInfo *session_info_ = nullptr;
   uint64_t *bit_map_ = nullptr;
   uint32_t *update_addr_ = nullptr;
+  int32_t topic_type_flag_ = -1;
 
   std::unique_ptr<uint8_t[]> ext_info_;
   size_t ext_info_len_ = 0;
diff --git a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc
index 44d4d042..d0974f11 100644
--- a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc
+++ b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc
@@ -154,4 +154,58 @@ TEST_F(UtestKernelExTaskInfo, parse_update_addr) {
   KernelExTaskInfo kernel_ex_task_info;
   EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
 }
+
+TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_1) {
+  const string ext_info = {7,0,0,0,4,0,0,0,0,0,0,0};
+  const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
+  AttrUtils::SetBool(op_desc, "_AllShape", true);
+
+  KernelExTaskInfo kernel_ex_task_info;
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
+}
+
+TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_2) {
+  const string ext_info = {7,0,0,0,4,0,0,0,1,0,0,0};
+  const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
+  AttrUtils::SetBool(op_desc, "_AllShape", true);
+
+  KernelExTaskInfo kernel_ex_task_info;
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
+}
+
+TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_3) {
+  const string ext_info = {7,0,0,0,4,0,0,0,2,0,0,0};
+  const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
+  AttrUtils::SetBool(op_desc, "_AllShape", true);
+
+  KernelExTaskInfo kernel_ex_task_info;
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
+}
+
+TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_4) {
+  const string ext_info = {7,0,0,0,4,0,0,0,3,0,0,0};
+  const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
+  AttrUtils::SetBool(op_desc, "_AllShape", true);
+
+  KernelExTaskInfo kernel_ex_task_info;
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
+}
+
+TEST_F(UtestKernelExTaskInfo, parse_topic_type_failed_1) {
+  const string ext_info = {7,0,0,0,4,0,0,0,4,0,0,0};
+  const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
+  AttrUtils::SetBool(op_desc, "_AllShape", true);
+
+  KernelExTaskInfo kernel_ex_task_info;
+  EXPECT_NE(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
+}
+
+TEST_F(UtestKernelExTaskInfo, parse_topic_type_failed_2) {
+  const string ext_info = {7,0,0,0,2,0,0,0,2,0,0,0};
+  const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp");
+  AttrUtils::SetBool(op_desc, "_AllShape", true);
+
+  KernelExTaskInfo kernel_ex_task_info;
+  EXPECT_NE(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
+}
 }  // namespace ge
diff --git a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h
index 7a2cbc50..df57c82e 100644
--- a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h
+++ b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h
@@ -61,9 +61,18 @@ enum FWKTaskExtInfoType {
   FWK_ADPT_EXT_OP_NAME,
   FWK_ADPT_EXT_SESSION_INFO,
   FWK_ADPT_EXT_BITMAP,
+  FWK_ADPT_EXT_TOPIC_TYPE,
   FWK_ADPT_EXT_INVALID
 };
 
+enum FWKExtTopicType {
+  FWK_ADPT_TOPIC_DEVICE_ONLY = 0,
+  FWK_ADPT_TOPIC_DEVICE_FIRST,
+  FWK_ADPT_TOPIC_HOST_ONLY,
+  FWK_ADPT_TOPIC_HOST_FIRST,
+  FWK_ADPT_TOPIC_INVALID
+};
+
 enum FWKExtUpdateAddrType {
   FWK_ADPT_UPDATE_NULL = 0,
   FWK_ADPT_UPDATE_INPUT,
diff --git a/third_party/fwkacllib/inc/runtime/kernel.h b/third_party/fwkacllib/inc/runtime/kernel.h
index b4500e10..402fadef 100644
--- a/third_party/fwkacllib/inc/runtime/kernel.h
+++ b/third_party/fwkacllib/inc/runtime/kernel.h
@@ -191,6 +191,11 @@ typedef void (*rtCallback_t)(void *fnData);
 #define RT_FUSION_KERNEL_DUMPFLAG (0x04)
 #define RT_KERNEL_CUSTOM_AICPU (0x08)
 
+// STARS topic scheduler sqe : topic_type
+#define RT_KERNEL_DEVICE_FIRST (0X10)
+#define RT_KERNEL_HOST_ONLY (0X20)
+#define RT_KERNEL_HOST_FIRST (0X30)
+
 /**
  * @ingroup rt_kernel
  * @brief kernel mode

From 5f46a1faee97bc73e7febe49e1e979030eb9cdc5 Mon Sep 17 00:00:00 2001
From: lianghao
Date: Thu, 22 Apr 2021 14:26:32 +0800
Subject: [PATCH 2/2] input_fp16_nodes in Case

---
 ge/graph/passes/multi_batch_clone_pass.cc     |   7 ++
 ge/graph/preprocess/graph_preprocess.cc       | 117 ++++++++++++++----
 .../preprocess/graph_preprocess_unittest.cc   |  75 +++++++++++
 3 files changed, 178 insertions(+), 21 deletions(-)

diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc
index 9e1fe80a..8d4bcb66 100755
--- a/ge/graph/passes/multi_batch_clone_pass.cc
+++ b/ge/graph/passes/multi_batch_clone_pass.cc
@@ -42,6 +42,7 @@ const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const";
 const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex";
 const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_";
 const char *const kGetNextName = "IteratorV2";
+const char *const kMbatchCaseName = "mbatch-switch-name";
 }  // namespace
 
 inline bool IsGetNextType(const NodePtr &node) {
@@ -943,6 +944,12 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_an
     }
   }
   (void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
+  if (!AttrUtils::SetStr(node->GetOpDesc(), kMbatchCaseName, case_node_->GetName())) {
+    REPORT_CALL_ERROR("E19999", "Set Attr:%s to node:%s(%s) failed",
+                      kMbatchCaseName, node->GetName().c_str(), node->GetType().c_str());
+    GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", node->GetName().c_str());
+    return INTERNAL_ERROR;
+  }
 
   GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex));
   std::vector<std::string> input_dims_str;
diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc
index 4fb80646..e4f7f231 100644
--- a/ge/graph/preprocess/graph_preprocess.cc
+++ b/ge/graph/preprocess/graph_preprocess.cc
@@ -609,7 +609,7 @@ Status ModifyDataNetOutputFormatAndShape(OpDescPtr &op_desc, uint32_t index, For
   return SUCCESS;
 }
 
-Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, NodePtr &switchn_node) {
+Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, NodePtr &mbatch_node, int32_t &index) {
   is_dynamic_batch = false;
   std::string related_node_name;
   if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) {
@@ -620,13 +620,17 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node
                                        data_node->GetName().c_str());
       return INTERNAL_ERROR;
     }
-    for (const NodePtr &next_node : data_node->GetOutNodes()) {
-      if (next_node->GetName() == related_node_name) {
-        switchn_node = next_node;
+
+    auto out_data_nodes_anchors = data_node->GetOutDataNodesAndAnchors();
+    for (const auto &out_data_node_anchor : out_data_nodes_anchors) {
+      if (out_data_node_anchor.first->GetName() == related_node_name) {
+        mbatch_node = out_data_node_anchor.first;
+        index = out_data_node_anchor.second->GetIdx();
         break;
       }
     }
-    if (switchn_node == nullptr) {
+
+    if (mbatch_node == nullptr) {
       ErrorManager::GetInstance().ATCReportErrMessage(
           "E15002", {"opname", "value", "reason"},
           {data_node->GetName(), related_node_name, "but can not find it on the graph"});
@@ -679,7 +683,7 @@ Status CheckIfNeedSetNdFormat(const NodePtr &node_ptr) {
 // In the dynamic shape process, transnode insertion by FE is advanced to the stage of whole
 // graph optimization, GE only sets the final data_type/format/shape information for variable,
 // data and netoutput, and no longer inserts the transnode.
-Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node, DataType &dt_set) {
+Status ProcessInputDtDynShape(NodePtr &node_ptr, NodePtr &switchn_node, DataType &dt_set) {
   GE_CHECK_NOTNULL(node_ptr);
   auto op_desc = node_ptr->GetOpDesc();
   GE_CHECK_NOTNULL(op_desc);
@@ -712,19 +716,84 @@ Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr
     GELOGI("[Process][InputDynShape] Set input and output size of node [%s] success.", node_ptr->GetName().c_str());
   }
 
-  if (is_dynamic_batch) {
-    GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str());
-    auto switchn_op_desc = switchn_node->GetOpDesc();
-    GE_CHECK_NOTNULL(switchn_op_desc);
-    auto switchn_input = switchn_op_desc->MutableInputDesc(0);
-    GE_CHECK_NOTNULL(switchn_input);
-    switchn_input->SetDataType(dt_set);
-    for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) {
-      const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i);
-      GE_CHECK_NOTNULL(switchn_output);
-      switchn_output->SetDataType(dt_set);
+  return SUCCESS;
+}
+
+Status UpdateInputOutputDataType(NodePtr &mbatch_node, DataType &dt_set, int32_t index) {
+  auto mbatch_desc = mbatch_node->GetOpDesc();
+  GE_CHECK_NOTNULL(mbatch_desc);
+  auto mbatch_input = mbatch_desc->MutableInputDesc(index);
+  GE_CHECK_NOTNULL(mbatch_input);
+  mbatch_input->SetDataType(dt_set);
+
+  if (mbatch_node->GetType() == SWITCHN) {
+    for (uint32_t i = 0; i < mbatch_node->GetAllOutDataAnchorsSize(); ++i) {
+      const GeTensorDescPtr &mbatch_output = mbatch_desc->MutableOutputDesc(i);
+      GE_CHECK_NOTNULL(mbatch_output);
+      mbatch_output->SetDataType(dt_set);
+    }
+  }
+
+  GELOGD("Update input and output data type of node[name: %s, type: %s, input index: %d] to %s.",
+         mbatch_node->GetName().c_str(), mbatch_node->GetType().c_str(), index,
+         TypeUtils::DataTypeToSerialString(dt_set).c_str());
+
+  return SUCCESS;
+}
+
+Status UpdateSubgraphDataOfCase(NodePtr &mbatch_node, DataType &dt_set, int32_t index) {
+  if (mbatch_node->GetType() != CASE) {
+    return SUCCESS;
+  }
+
+  auto subgraphs = NodeUtils::GetAllSubgraphs(*mbatch_node);
+  for (const auto &subgraph : subgraphs) {
+    GE_CHECK_NOTNULL(subgraph);
+    for (auto &sub_node : subgraph->GetDirectNode()) {
+      GE_CHECK_NOTNULL(sub_node);
+      if (sub_node->GetType() != DATA) {
+        continue;
+      }
+
+      auto data_desc = sub_node->GetOpDesc();
+      GE_CHECK_NOTNULL(data_desc);
+      int32_t parent_node_index = 0;
+      if (!AttrUtils::GetInt(data_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index) ||
+          (parent_node_index != index)) {
+        continue;
+      }
+
+      auto data_input = data_desc->MutableInputDesc(0);
+      GE_CHECK_NOTNULL(data_input);
+      data_input->SetDataType(dt_set);
+      auto data_output = data_desc->MutableOutputDesc(0);
+      GE_CHECK_NOTNULL(data_output);
+      data_output->SetDataType(dt_set);
+      GELOGD("Update input and output data type of node[name: %s, type: %s, parent_node_index: %d] in subgraph %s "
+             "to %s.", data_desc->GetName().c_str(), data_desc->GetType().c_str(), parent_node_index,
+             subgraph->GetName().c_str(), TypeUtils::DataTypeToSerialString(dt_set).c_str());
     }
   }
+
+  return SUCCESS;
+}
+
+Status ProcessMbatchScene(NodePtr &mbatch_node, DataType &dt_set, int32_t index) {
+  GELOGI("The node [%s] dtype set fp16.", mbatch_node->GetName().c_str());
+  if (UpdateInputOutputDataType(mbatch_node, dt_set, index) != SUCCESS) {
+    GELOGE(FAILED, "Update input and output data type of node[name: %s, type: %s] to %s failed.",
+           mbatch_node->GetName().c_str(), mbatch_node->GetType().c_str(),
+           TypeUtils::DataTypeToSerialString(dt_set).c_str());
+    return FAILED;
+  }
+
+  if (UpdateSubgraphDataOfCase(mbatch_node, dt_set, index) != SUCCESS) {
+    GELOGE(FAILED, "Update input and output data type of Data node[parent_node_index: %d] in subgraphs of "
+           "node[name: %s, type: %s] to %s failed.", index, mbatch_node->GetName().c_str(),
+           mbatch_node->GetType().c_str(), TypeUtils::DataTypeToSerialString(dt_set).c_str());
+    return FAILED;
+  }
+
   return SUCCESS;
 }
@@ -785,21 +854,27 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) {
   DataType dt_set = TypeUtils::SerialStringToDataType(set_dt_str);
   GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str());
   bool is_dynamic_batch = false;
-  NodePtr switchn_node = nullptr;
-  if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, switchn_node)) {
+  NodePtr mbatch_node = nullptr;
+  int32_t index = 0;
+  if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, mbatch_node, index)) {
     GELOGE(INTERNAL_ERROR, "CheckIfDynamicBatchScene failed");
     return FAILED;
   }
-  if (ProcessInputDtDynShape(node_ptr, is_dynamic_batch, switchn_node, dt_set) != SUCCESS) {
+  if (ProcessInputDtDynShape(node_ptr, mbatch_node, dt_set) != SUCCESS) {
     GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed");
     return FAILED;
   }
+  if (is_dynamic_batch && ProcessMbatchScene(mbatch_node, dt_set, index) != SUCCESS) {
+    GELOGE(INTERNAL_ERROR, "ProcessMbatchScene failed");
+    return FAILED;
+  }
+
   // check if need to set format
   string set_format;
   bool ret = ge::AttrUtils::GetStr(node_ptr->GetOpDesc(), ATTR_ATC_USER_DEFINE_FORMAT, set_format);
   if (ret && (!set_format.empty()) && TypeUtils::SerialStringToFormat(set_format) == FORMAT_NC1HWC0) {
     GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str());
-    if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) {
+    if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, mbatch_node) != SUCCESS) {
       GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed");
       return FAILED;
     }
diff --git a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc
index ff49f34c..8d0be31d 100644
--- a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc
+++ b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc
@@ -72,6 +72,41 @@ ComputeGraphPtr BuildGraph3() {
   return builder.GetGraph();
 }
 
+/*
+ *  MapIndex  Data1                subgraph1   subgraph2
+ *         \  /
+ *         Case        ===>          Data2       Data3
+ *           |
+ *       Netoutput
+ */
+ComputeGraphPtr BuildGraph4() {
+  auto builder = ut::GraphBuilder("mbatch_Case");
+
+  auto data1 = builder.AddNode("data1", DATA, 1, 1);
+  auto data_desc = data1->GetOpDesc();
+  AttrUtils::SetStr(data_desc, ATTR_ATC_USER_DEFINE_DATATYPE, "DT_FLOAT16");
+  AttrUtils::SetStr(data_desc, "mbatch-switch-name", "case1");
+  AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, 0);
+
+  auto mapindex1 = builder.AddNode("mapindex1", "MapIndex", 0, 1);
+  auto case1 = builder.AddNode("case1", CASE, 2, 1);
+  auto netoutput1 = builder.AddNode("netoutput1", NETOUTPUT, 1, 0);
+
+  builder.AddDataEdge(mapindex1, 0, case1, 0);
+  builder.AddDataEdge(data1, 0, case1, 1);
+  builder.AddDataEdge(case1, 0, netoutput1, 0);
+
+  return builder.GetGraph();
+}
+
+ComputeGraphPtr BuildGraph4_Subgraph(string graph_name) {
+  auto builder = ut::GraphBuilder(graph_name);
+  auto data1 = builder.AddNode(graph_name + "_data1", DATA, 1, 1);
+  auto data_desc = data1->GetOpDesc();
+  AttrUtils::SetInt(data_desc, ATTR_NAME_PARENT_NODE_INDEX, 1);
+  return builder.GetGraph();
+}
+
 TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) {
   ge::GraphPrepare graph_prepare;
   graph_prepare.compute_graph_ = BuildGraph1();
@@ -118,4 +153,44 @@ TEST_F(UtestGraphPreproces, test_update_input_output1) {
   Status ret = graph_prepare.UpdateInputOutputByOptions();
   EXPECT_EQ(ret, SUCCESS);
 }
+
+TEST_F(UtestGraphPreproces, test_update_dtype_mbatch_case) {
+  ge::GraphPrepare graph_prepare;
+  graph_prepare.compute_graph_ = BuildGraph4();
+  auto parent_graph = graph_prepare.compute_graph_;
+  auto subgraph1 = BuildGraph4_Subgraph("subgraph1");
+  auto subgraph2 = BuildGraph4_Subgraph("subgraph2");
+
+  auto data1 = parent_graph->FindNode("data1");
+  auto data_desc = data1->GetOpDesc();
+
+  auto case_node = parent_graph->FindNode("case1");
+  EXPECT_NE(case_node, nullptr);
+  case_node->GetOpDesc()->AddSubgraphName("subgraph1");
+  case_node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph1");
+  subgraph1->SetParentNode(case_node);
+  subgraph1->SetParentGraph(parent_graph);
+  EXPECT_EQ(parent_graph->AddSubgraph("subgraph1", subgraph1), GRAPH_SUCCESS);
+
+  case_node->GetOpDesc()->AddSubgraphName("subgraph2");
+  case_node->GetOpDesc()->SetSubgraphInstanceName(1, "subgraph2");
+  subgraph2->SetParentNode(case_node);
+  subgraph2->SetParentGraph(parent_graph);
+  EXPECT_EQ(parent_graph->AddSubgraph("subgraph2", subgraph2), GRAPH_SUCCESS);
+
+  Status ret = graph_prepare.UpdateInputOutputByOptions();
+  EXPECT_EQ(ret, SUCCESS);
+
+  auto case_desc = case_node->GetOpDesc();
+  auto case_input = case_desc->MutableInputDesc(1);
+  EXPECT_EQ(case_input->GetDataType(), 1);
+
+  auto sub1_data1 = subgraph1->FindNode("subgraph1_data1");
+  EXPECT_NE(sub1_data1, nullptr);
+  auto data1_desc = sub1_data1->GetOpDesc();
+  auto data1_input = data1_desc->MutableInputDesc(0);
+  EXPECT_EQ(data1_input->GetDataType(), 1);
+  auto data1_output = data1_desc->MutableOutputDesc(0);
+  EXPECT_EQ(data1_output->GetDataType(), 1);
+}
 }
\ No newline at end of file
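The following sketch is not part of either patch; it is a minimal, standalone C++ illustration of the topic_type mechanism introduced in PATCH 1: the FWKExtTopicType value carried in the aicpu ext-info is mapped onto the RT_KERNEL_* scheduling bits (bits 4-5) and, when non-zero, OR-ed into the launch flag. The constant values mirror fwk_adpt_struct.h and runtime/kernel.h above; the function name, main() and the sample values are illustrative assumptions only.

// Standalone sketch (assumed names, not the patched GE code): mirrors the
// kTopicTypeToRtsFlagMap lookup and the dump_flag_ update in Distribute().
#include <cstdint>
#include <cstdio>
#include <map>

// Runtime flag bits, as added to runtime/kernel.h (bits 4-5 of the flag word).
constexpr int32_t kRtKernelDeviceFirst = 0x10;  // RT_KERNEL_DEVICE_FIRST
constexpr int32_t kRtKernelHostOnly    = 0x20;  // RT_KERNEL_HOST_ONLY
constexpr int32_t kRtKernelHostFirst   = 0x30;  // RT_KERNEL_HOST_FIRST

// FWKExtTopicType values (0..3) -> runtime flag bits; unknown values map to -1.
int32_t TopicTypeToRtsFlag(int32_t topic_type) {
  static const std::map<int32_t, int32_t> kMap{
      {0, 0},                      // FWK_ADPT_TOPIC_DEVICE_ONLY -> no extra bits
      {1, kRtKernelDeviceFirst},   // FWK_ADPT_TOPIC_DEVICE_FIRST
      {2, kRtKernelHostOnly},      // FWK_ADPT_TOPIC_HOST_ONLY
      {3, kRtKernelHostFirst}};    // FWK_ADPT_TOPIC_HOST_FIRST
  const auto it = kMap.find(topic_type);
  return (it == kMap.end()) ? -1 : it->second;
}

int main() {
  uint32_t dump_flag = 0x01U;  // pretend a dump bit is already set on the task
  const int32_t topic_type_flag = TopicTypeToRtsFlag(3);  // HOST_FIRST from ext-info
  if (topic_type_flag > 0) {
    // Fold the topic type into bits 4-5 of the flag passed to the kernel launch.
    dump_flag |= static_cast<uint32_t>(topic_type_flag);
  }
  std::printf("dump_flag = 0x%X\n", dump_flag);  // prints 0x31
  return 0;
}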