| @@ -53,6 +53,7 @@ Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDe | |||
| "Parse kernel ext info failed, kernel_ext_info_size=%zu.", ext_info.size()); | |||
| GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed."); | |||
| GELOGD("Update aicpu_task ext_info bit_map execute mode to 1."); | |||
| topic_type_flag_ = ext_handle->GetTopicTypeFlag(); | |||
| bool all_shape = false; | |||
| (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape); | |||
| @@ -407,6 +408,14 @@ Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const | |||
| Status KernelExTaskInfo::Distribute() { | |||
| GELOGI("KernelExTaskInfo Distribute Start."); | |||
| // Use the fifth and sixth bits of dump_flag_ to indicate the value of topic_type. | |||
| // xxxxxxxx xxxxxxxx xxxxxxxx xx00xxxx: DEVICE_ONLY | |||
| // xxxxxxxx xxxxxxxx xxxxxxxx xx01xxxx: DEVICE_FIRST | |||
| // xxxxxxxx xxxxxxxx xxxxxxxx xx10xxxx: HOST_ONLY | |||
| // xxxxxxxx xxxxxxxx xxxxxxxx xx11xxxx: HOST_FIRST | |||
| if (topic_type_flag_ > 0) { | |||
| dump_flag_ = dump_flag_ | topic_type_flag_; | |||
| } | |||
| rtError_t rt_ret = rtKernelLaunchEx(kernel_buf_, kernel_buf_size_, dump_flag_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| REPORT_CALL_ERROR("E19999", "Call rtKernelLaunchEx failed, ret:0x%X", | |||
| @@ -76,6 +76,7 @@ class KernelExTaskInfo : public TaskInfo { | |||
| vector<void *> io_addrs_; | |||
| uint32_t args_offset_ = 0; | |||
| int64_t fixed_addr_offset_ = 0; | |||
| int32_t topic_type_flag_ = -1; | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_ | |||
| @@ -431,6 +431,14 @@ Status KernelTaskInfo::Distribute() { | |||
| int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, kBaseInt) : kStrtolFail; | |||
| bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); | |||
| if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) { | |||
| if (topic_type_flag_ > 0) { | |||
| // Use the fifth and sixth bits of dump_flag_ to indicate the value of topic_type. | |||
| // xxxxxxxx xxxxxxxx xxxxxxxx xx00xxxx: DEVICE_ONLY | |||
| // xxxxxxxx xxxxxxxx xxxxxxxx xx01xxxx: DEVICE_FIRST | |||
| // xxxxxxxx xxxxxxxx xxxxxxxx xx10xxxx: HOST_ONLY | |||
| // xxxxxxxx xxxxxxxx xxxxxxxx xx11xxxx: HOST_FIRST | |||
| dump_flag_ = dump_flag_ | topic_type_flag_; | |||
| } | |||
| GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_); | |||
| // blockDim is reserved parameter, set to 1 | |||
| rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), | |||
| @@ -1116,6 +1124,7 @@ Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { | |||
| GELOGD("Update aicpu_task ext_info session_info session_id is %lu", davinci_model_->GetSessionId()); | |||
| GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed."); | |||
| GELOGD("Update aicpu_task ext_info bit_map execute mode to 1."); | |||
| topic_type_flag_ = ext_handle->GetTopicTypeFlag(); | |||
| bool all_shape = false; | |||
| (void)AttrUtils::GetBool(op_desc_, kAicpuAllshape, all_shape); | |||
| @@ -169,6 +169,7 @@ class KernelTaskInfo : public TaskInfo { | |||
| uint16_t io_addr_offset_ = 0; | |||
| bool l2_buffer_on_ = false; | |||
| bool call_save_dump_ = false; | |||
| int32_t topic_type_flag_ = -1; | |||
| // aicpu ext_info device mem | |||
| void *aicpu_ext_info_addr_ = nullptr; | |||
| @@ -42,6 +42,7 @@ const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const"; | |||
| const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex"; | |||
| const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_"; | |||
| const char *const kGetNextName = "IteratorV2"; | |||
| const char *const kMbatchCaseName = "mbatch-switch-name"; | |||
| } // namespace | |||
| inline bool IsGetNextType(const NodePtr &node) { | |||
| @@ -943,6 +944,12 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_an | |||
| } | |||
| } | |||
| (void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); | |||
| if (!AttrUtils::SetStr(node->GetOpDesc(), kMbatchCaseName, case_node_->GetName())) { | |||
| REPORT_CALL_ERROR("E19999", "Set Attr:%s to node:%s(%s) failed", | |||
| kMbatchCaseName, node->GetName().c_str(), node->GetType().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", node->GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex)); | |||
| std::vector<std::string> input_dims_str; | |||
| @@ -610,7 +610,7 @@ Status ModifyDataNetOutputFormatAndShape(OpDescPtr &op_desc, uint32_t index, For | |||
| return SUCCESS; | |||
| } | |||
| Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, NodePtr &switchn_node) { | |||
| Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, NodePtr &mbatch_node, int32_t &index) { | |||
| is_dynamic_batch = false; | |||
| std::string related_node_name; | |||
| if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) { | |||
| @@ -621,13 +621,17 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node | |||
| data_node->GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| for (const NodePtr &next_node : data_node->GetOutNodes()) { | |||
| if (next_node->GetName() == related_node_name) { | |||
| switchn_node = next_node; | |||
| auto out_data_nodes_anchors = data_node->GetOutDataNodesAndAnchors(); | |||
| for (const auto &out_data_node_anchor : out_data_nodes_anchors) { | |||
| if (out_data_node_anchor.first->GetName() == related_node_name) { | |||
| mbatch_node = out_data_node_anchor.first; | |||
| index = out_data_node_anchor.second->GetIdx(); | |||
| break; | |||
| } | |||
| } | |||
| if (switchn_node == nullptr) { | |||
| if (mbatch_node == nullptr) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E15002", {"opname", "value", "reason"}, | |||
| {data_node->GetName(), related_node_name, "but can not find it on the graph"}); | |||
| @@ -680,7 +684,7 @@ Status CheckIfNeedSetNdFormat(const NodePtr &node_ptr) { | |||
| // In the dynamic shape process, transnode insertion by FE is advanced to the stage of whole | |||
| // graph optimization, GE only sets the final data_type/format/shape information for variable, | |||
| // data and netoutput, and no longer inserts the transnode. | |||
| Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node, DataType &dt_set) { | |||
| Status ProcessInputDtDynShape(NodePtr &node_ptr, NodePtr &switchn_node, DataType &dt_set) { | |||
| GE_CHECK_NOTNULL(node_ptr); | |||
| auto op_desc = node_ptr->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| @@ -713,19 +717,84 @@ Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr | |||
| GELOGI("[Process][InputDynShape] Set input and output size of node [%s] success.", node_ptr->GetName().c_str()); | |||
| } | |||
| if (is_dynamic_batch) { | |||
| GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); | |||
| auto switchn_op_desc = switchn_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(switchn_op_desc); | |||
| auto switchn_input = switchn_op_desc->MutableInputDesc(0); | |||
| GE_CHECK_NOTNULL(switchn_input); | |||
| switchn_input->SetDataType(dt_set); | |||
| for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { | |||
| const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); | |||
| GE_CHECK_NOTNULL(switchn_output); | |||
| switchn_output->SetDataType(dt_set); | |||
| return SUCCESS; | |||
| } | |||
| // Applies dt_set to input `index` of mbatch_node. When the node type is SWITCHN, every output | |||
| // desc is retyped as well. Each descriptor is guarded by GE_CHECK_NOTNULL before use. | |||
| Status UpdateInputOutputDataType(NodePtr &mbatch_node, DataType &dt_set, int32_t index) { | |||
|   const auto op_desc = mbatch_node->GetOpDesc(); | |||
|   GE_CHECK_NOTNULL(op_desc); | |||
|   const auto input_desc = op_desc->MutableInputDesc(index); | |||
|   GE_CHECK_NOTNULL(input_desc); | |||
|   input_desc->SetDataType(dt_set); | |||
|   // Only SWITCHN nodes get their outputs updated here; other node types keep their output descs. | |||
|   const bool is_switchn = (mbatch_node->GetType() == SWITCHN); | |||
|   if (is_switchn) { | |||
|     for (uint32_t out_idx = 0; out_idx < mbatch_node->GetAllOutDataAnchorsSize(); ++out_idx) { | |||
|       const auto output_desc = op_desc->MutableOutputDesc(out_idx); | |||
|       GE_CHECK_NOTNULL(output_desc); | |||
|       output_desc->SetDataType(dt_set); | |||
|     } | |||
|   } | |||
|   GELOGD("Update input and output data type of node[name: %s, type: %s, input index: %d] to %s.", | |||
|          mbatch_node->GetName().c_str(), mbatch_node->GetType().c_str(), index, | |||
|          TypeUtils::DataTypeToSerialString(dt_set).c_str()); | |||
|   return SUCCESS; | |||
| } | |||
| // For a CASE node, walks the direct nodes of each of its subgraphs and sets dt_set on the | |||
| // input/output desc 0 of every DATA node whose ATTR_NAME_PARENT_NODE_INDEX equals `index`. | |||
| // Any other node type is a no-op that returns SUCCESS. | |||
| Status UpdateSubgraphDataOfCase(NodePtr &mbatch_node, DataType &dt_set, int32_t index) { | |||
|   const bool is_case_node = (mbatch_node->GetType() == CASE); | |||
|   if (!is_case_node) { | |||
|     return SUCCESS; | |||
|   } | |||
|   const auto case_subgraphs = NodeUtils::GetAllSubgraphs(*mbatch_node); | |||
|   for (const auto &case_subgraph : case_subgraphs) { | |||
|     GE_CHECK_NOTNULL(case_subgraph); | |||
|     for (auto &candidate : case_subgraph->GetDirectNode()) { | |||
|       GE_CHECK_NOTNULL(candidate); | |||
|       if (candidate->GetType() != DATA) { | |||
|         continue; | |||
|       } | |||
|       auto candidate_desc = candidate->GetOpDesc(); | |||
|       GE_CHECK_NOTNULL(candidate_desc); | |||
|       // Skip Data nodes without a parent index, or ones fed by a different input of the Case node. | |||
|       int32_t parent_node_index = 0; | |||
|       const bool has_parent_index = | |||
|           AttrUtils::GetInt(candidate_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index); | |||
|       if (!has_parent_index || (parent_node_index != index)) { | |||
|         continue; | |||
|       } | |||
|       auto in_desc = candidate_desc->MutableInputDesc(0); | |||
|       GE_CHECK_NOTNULL(in_desc); | |||
|       in_desc->SetDataType(dt_set); | |||
|       auto out_desc = candidate_desc->MutableOutputDesc(0); | |||
|       GE_CHECK_NOTNULL(out_desc); | |||
|       out_desc->SetDataType(dt_set); | |||
|       GELOGD("Update input and output data type of node[name: %s, type: %s, parent_node_index: %d] in subgraph %s " | |||
|              "to %s.", candidate_desc->GetName().c_str(), candidate_desc->GetType().c_str(), parent_node_index, | |||
|              case_subgraph->GetName().c_str(), TypeUtils::DataTypeToSerialString(dt_set).c_str()); | |||
|     } | |||
|   } | |||
|   return SUCCESS; | |||
| } | |||
| // Multi-batch handling for a user-specified data type: first retypes the multi-batch node itself, | |||
| // then the matching Data nodes inside its subgraphs. Returns FAILED as soon as either step fails. | |||
| Status ProcessMbatchScene(NodePtr &mbatch_node, DataType &dt_set, int32_t index) { | |||
|   GELOGI("The node [%s] dtype set fp16.", mbatch_node->GetName().c_str()); | |||
|   const Status node_status = UpdateInputOutputDataType(mbatch_node, dt_set, index); | |||
|   if (node_status != SUCCESS) { | |||
|     GELOGE(FAILED, "Update input and output data type of node[name: %s, type: %s] to %s failed.", | |||
|            mbatch_node->GetName().c_str(), mbatch_node->GetType().c_str(), | |||
|            TypeUtils::DataTypeToSerialString(dt_set).c_str()); | |||
|     return FAILED; | |||
|   } | |||
|   const Status subgraph_status = UpdateSubgraphDataOfCase(mbatch_node, dt_set, index); | |||
|   if (subgraph_status != SUCCESS) { | |||
|     GELOGE(FAILED, "Update input and output data type of Data node[parent_node_index: %d] in subgraphs of " | |||
|            "node[name: %s, type: %s] to %s failed.", index, mbatch_node->GetName().c_str(), | |||
|            mbatch_node->GetType().c_str(), TypeUtils::DataTypeToSerialString(dt_set).c_str()); | |||
|     return FAILED; | |||
|   } | |||
|   return SUCCESS; | |||
| } | |||
| @@ -786,21 +855,27 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) { | |||
| DataType dt_set = TypeUtils::SerialStringToDataType(set_dt_str); | |||
| GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str()); | |||
| bool is_dynamic_batch = false; | |||
| NodePtr switchn_node = nullptr; | |||
| if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, switchn_node)) { | |||
| NodePtr mbatch_node = nullptr; | |||
| int32_t index = 0; | |||
| if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, mbatch_node, index)) { | |||
| GELOGE(INTERNAL_ERROR, "CheckIfDynamicBatchScene failed"); | |||
| return FAILED; | |||
| } | |||
| if (ProcessInputDtDynShape(node_ptr, is_dynamic_batch, switchn_node, dt_set) != SUCCESS) { | |||
| if (ProcessInputDtDynShape(node_ptr, mbatch_node, dt_set) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed"); | |||
| return FAILED; | |||
| } | |||
| if (is_dynamic_batch && ProcessMbatchScene(mbatch_node, dt_set, index) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "ProcessMbatchScene failed"); | |||
| return FAILED; | |||
| } | |||
| // check if need to set format | |||
| string set_format; | |||
| bool ret = ge::AttrUtils::GetStr(node_ptr->GetOpDesc(), ATTR_ATC_USER_DEFINE_FORMAT, set_format); | |||
| if (ret && (!set_format.empty()) && TypeUtils::SerialStringToFormat(set_format) == FORMAT_NC1HWC0) { | |||
| GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str()); | |||
| if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { | |||
| if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, mbatch_node) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed"); | |||
| return FAILED; | |||
| } | |||
| @@ -24,6 +24,12 @@ namespace hybrid { | |||
| namespace { | |||
| // If the dim count does not reach kMaxShapeDims(8), use INT64_MIN to mark the dim end. | |||
| constexpr int64_t kDimEndFlag = INT64_MIN; | |||
| // Lookup table from the aicpu ext-info topic type to the rts kernel-launch flag value. | |||
| // DEVICE_ONLY maps to 0, i.e. it contributes no flag bits. | |||
| const std::map<int32_t, int32_t> kTopicTypeToRtsFlagMap = { | |||
|     {static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_ONLY), 0}, | |||
|     {static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_FIRST), RT_KERNEL_DEVICE_FIRST}, | |||
|     {static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_ONLY), RT_KERNEL_HOST_ONLY}, | |||
|     {static_cast<int32_t>(aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_FIRST), RT_KERNEL_HOST_FIRST}, | |||
| }; | |||
| } | |||
| Status AicpuExtInfoHandler::Parse(const std::string &ext_info) { | |||
| @@ -72,6 +78,9 @@ Status AicpuExtInfoHandler::Parse(const std::string &ext_info) { | |||
| case aicpu::FWKAdapter::FWK_ADPT_EXT_UPDATE_ADDR: | |||
| GE_CHK_STATUS_RET(ParseExtUpdateAddr(aicpu_ext_info), "[Parse][ExtUpdateAddr] failed."); | |||
| break; | |||
| case aicpu::FWKAdapter::FWK_ADPT_EXT_TOPIC_TYPE: | |||
| GE_CHK_STATUS_RET(ParseExtTopicType(aicpu_ext_info), "[Parse][ExtTopicType] failed."); | |||
| break; | |||
| default: | |||
| GELOGD("Node[%s] ignore infoType=%d, infoLen=%u.", | |||
| node_name_.c_str(), aicpu_ext_info->infoType, aicpu_ext_info->infoLen); | |||
| @@ -207,6 +216,44 @@ Status AicpuExtInfoHandler::ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info) { | |||
| return SUCCESS; | |||
| } | |||
| // Parses a topic-type ext-info entry and caches the corresponding rts flag in topic_type_flag_. | |||
| // The entry must carry exactly sizeof(int32_t) bytes holding one of the FWK_ADPT_TOPIC_* values | |||
| // known to TopicTypeToRtsFlag; otherwise ACL_ERROR_GE_PARAM_INVALID is returned. | |||
| // Note: on an unknown topic type topic_type_flag_ is left at -1. | |||
| Status AicpuExtInfoHandler::ParseExtTopicType(AicpuExtInfo *aicpu_ext_info) { | |||
|   if (aicpu_ext_info->infoLen != sizeof(int32_t)) { | |||
|     REPORT_INNER_ERROR("E19999", | |||
|                        "Node[%s] parse topic_type info failed as infoLen must be %zu but %u.", | |||
|                        node_name_.c_str(), sizeof(int32_t), aicpu_ext_info->infoLen); | |||
|     GELOGE(ACL_ERROR_GE_PARAM_INVALID, | |||
|            "[Check][DataLen]Node[%s] parse topic_type info failed as infoLen must be %zu but %u.", | |||
|            node_name_.c_str(), sizeof(int32_t), aicpu_ext_info->infoLen); | |||
|     return ACL_ERROR_GE_PARAM_INVALID; | |||
|   } | |||
|   GE_CHECK_NOTNULL(aicpu_ext_info->infoMsg); | |||
|   auto type = *reinterpret_cast<const int32_t *>(aicpu_ext_info->infoMsg); | |||
|   topic_type_flag_ = TopicTypeToRtsFlag(type); | |||
|   if (topic_type_flag_ == -1) { | |||
|     REPORT_INNER_ERROR("E19999", "Node[%s] parse ext topic type failed as need %d %d %d %d but %d.", | |||
|                        node_name_.c_str(), | |||
|                        aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_ONLY, | |||
|                        aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_FIRST, | |||
|                        aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_ONLY, | |||
|                        aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_FIRST, | |||
|                        type); | |||
|     // Log text fixed: this path validates the topic type, not the shape type. | |||
|     GELOGE(ACL_ERROR_GE_PARAM_INVALID, | |||
|            "[Check][Type]Node[%s] parse ext topic type failed as need %d %d %d %d but %d.", | |||
|            node_name_.c_str(), | |||
|            aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_ONLY, | |||
|            aicpu::FWKAdapter::FWK_ADPT_TOPIC_DEVICE_FIRST, | |||
|            aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_ONLY, | |||
|            aicpu::FWKAdapter::FWK_ADPT_TOPIC_HOST_FIRST, | |||
|            type); | |||
|     return ACL_ERROR_GE_PARAM_INVALID; | |||
|   } | |||
|   GELOGI("Node[%s] parse ext topic type info success infoLen=%u, topic_type=%d, rts_flag=%d.", | |||
|          node_name_.c_str(), aicpu_ext_info->infoLen, type, topic_type_flag_); | |||
|   return SUCCESS; | |||
| } | |||
| Status AicpuExtInfoHandler::UpdateExecuteMode(bool flag) { | |||
| if (bit_map_ == nullptr) { | |||
| GELOGD("There is no bit_map in ext_info, no need update."); | |||
| @@ -341,5 +388,14 @@ void AicpuExtInfoHandler::GetShapeAndType(const AicpuShapeAndType *shape_and_typ | |||
| data_type = static_cast<DataType>(shape_and_type->type); | |||
| shape = GeShape(dims); | |||
| } | |||
| // Looks topic_type up in kTopicTypeToRtsFlagMap; returns the mapped rts flag, or -1 when the | |||
| // topic type is not in the table. | |||
| int32_t AicpuExtInfoHandler::TopicTypeToRtsFlag(int32_t topic_type) { | |||
|   const auto found = kTopicTypeToRtsFlagMap.find(topic_type); | |||
|   return (found == kTopicTypeToRtsFlagMap.end()) ? -1 : found->second; | |||
| } | |||
| } // namespace hybrid | |||
| } // namespace ge | |||
| @@ -62,6 +62,7 @@ class AicpuExtInfoHandler { | |||
| Status GetOutputShapeAndType(uint32_t output_index, GeShape &shape, DataType &data_type); | |||
| bool IsNeedRefreshIOAddr(); | |||
| // Returns the cached rts flag for the parsed topic type (-1 until a topic type has been parsed). | |||
| int32_t GetTopicTypeFlag() const { | |||
|   return topic_type_flag_; | |||
| } | |||
| private: | |||
| @@ -71,6 +72,7 @@ class AicpuExtInfoHandler { | |||
| Status ParseExtSessionInfo(AicpuExtInfo *aicpu_ext_info); | |||
| Status ParseExtBitMap(AicpuExtInfo *aicpu_ext_info); | |||
| Status ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info); | |||
| Status ParseExtTopicType(AicpuExtInfo *aicpu_ext_info); | |||
| static Status UpdateShapeAndType(const GeShape &shape, | |||
| DataType data_type, | |||
| @@ -81,6 +83,8 @@ class AicpuExtInfoHandler { | |||
| DataType &data_type); | |||
| private: | |||
| int32_t TopicTypeToRtsFlag(int32_t topic_type); | |||
| const std::string node_name_; | |||
| const uint32_t input_num_; | |||
| const uint32_t output_num_; | |||
| @@ -88,6 +92,7 @@ class AicpuExtInfoHandler { | |||
| AicpuSessionInfo *session_info_ = nullptr; | |||
| uint64_t *bit_map_ = nullptr; | |||
| uint32_t *update_addr_ = nullptr; | |||
| int32_t topic_type_flag_ = -1; | |||
| std::unique_ptr<uint8_t[]> ext_info_; | |||
| size_t ext_info_len_ = 0; | |||
| @@ -155,4 +155,58 @@ TEST_F(UtestKernelExTaskInfo, parse_update_addr) { | |||
| KernelExTaskInfo kernel_ex_task_info; | |||
| EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS); | |||
| } | |||
| // A topic-type ext info with payload 0 must be accepted. | |||
| // NOTE(review): bytes look like little-endian {infoType=7, infoLen=4, topic_type=0} - confirm layout. | |||
| TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_1) { | |||
|   const string raw_ext_info = {7, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0}; | |||
|   const OpDescPtr framework_op = CreateOpDesc("FrameworkOp", "FrameworkOp"); | |||
|   AttrUtils::SetBool(framework_op, "_AllShape", true); | |||
|   KernelExTaskInfo task_info; | |||
|   const auto init_ret = task_info.InitTaskExtInfo(raw_ext_info, framework_op); | |||
|   EXPECT_EQ(init_ret, SUCCESS); | |||
| } | |||
| // A topic-type ext info with payload 1 must be accepted. | |||
| // NOTE(review): bytes look like little-endian {infoType=7, infoLen=4, topic_type=1} - confirm layout. | |||
| TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_2) { | |||
|   const string raw_ext_info = {7, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0}; | |||
|   const OpDescPtr framework_op = CreateOpDesc("FrameworkOp", "FrameworkOp"); | |||
|   AttrUtils::SetBool(framework_op, "_AllShape", true); | |||
|   KernelExTaskInfo task_info; | |||
|   const auto init_ret = task_info.InitTaskExtInfo(raw_ext_info, framework_op); | |||
|   EXPECT_EQ(init_ret, SUCCESS); | |||
| } | |||
| // A topic-type ext info with payload 2 must be accepted. | |||
| // NOTE(review): bytes look like little-endian {infoType=7, infoLen=4, topic_type=2} - confirm layout. | |||
| TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_3) { | |||
|   const string raw_ext_info = {7, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}; | |||
|   const OpDescPtr framework_op = CreateOpDesc("FrameworkOp", "FrameworkOp"); | |||
|   AttrUtils::SetBool(framework_op, "_AllShape", true); | |||
|   KernelExTaskInfo task_info; | |||
|   const auto init_ret = task_info.InitTaskExtInfo(raw_ext_info, framework_op); | |||
|   EXPECT_EQ(init_ret, SUCCESS); | |||
| } | |||
| // A topic-type ext info with payload 3 must be accepted. | |||
| // NOTE(review): bytes look like little-endian {infoType=7, infoLen=4, topic_type=3} - confirm layout. | |||
| TEST_F(UtestKernelExTaskInfo, parse_topic_type_success_4) { | |||
|   const string raw_ext_info = {7, 0, 0, 0, 4, 0, 0, 0, 3, 0, 0, 0}; | |||
|   const OpDescPtr framework_op = CreateOpDesc("FrameworkOp", "FrameworkOp"); | |||
|   AttrUtils::SetBool(framework_op, "_AllShape", true); | |||
|   KernelExTaskInfo task_info; | |||
|   const auto init_ret = task_info.InitTaskExtInfo(raw_ext_info, framework_op); | |||
|   EXPECT_EQ(init_ret, SUCCESS); | |||
| } | |||
| // A topic-type ext info with payload 4 must be rejected. | |||
| // NOTE(review): 4 presumably corresponds to FWK_ADPT_TOPIC_INVALID, which has no rts flag - confirm. | |||
| TEST_F(UtestKernelExTaskInfo, parse_topic_type_failed_1) { | |||
|   const string raw_ext_info = {7, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0}; | |||
|   const OpDescPtr framework_op = CreateOpDesc("FrameworkOp", "FrameworkOp"); | |||
|   AttrUtils::SetBool(framework_op, "_AllShape", true); | |||
|   KernelExTaskInfo task_info; | |||
|   const auto init_ret = task_info.InitTaskExtInfo(raw_ext_info, framework_op); | |||
|   EXPECT_NE(init_ret, SUCCESS); | |||
| } | |||
| // A topic-type ext info whose length field is 2 must be rejected. | |||
| // NOTE(review): the second 4-byte group looks like infoLen; 2 differs from sizeof(int32_t) - confirm layout. | |||
| TEST_F(UtestKernelExTaskInfo, parse_topic_type_failed_2) { | |||
|   const string raw_ext_info = {7, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0}; | |||
|   const OpDescPtr framework_op = CreateOpDesc("FrameworkOp", "FrameworkOp"); | |||
|   AttrUtils::SetBool(framework_op, "_AllShape", true); | |||
|   KernelExTaskInfo task_info; | |||
|   const auto init_ret = task_info.InitTaskExtInfo(raw_ext_info, framework_op); | |||
|   EXPECT_NE(init_ret, SUCCESS); | |||
| } | |||
| } // namespace ge | |||
| @@ -72,8 +72,8 @@ ComputeGraphPtr BuildGraph3() { | |||
| return builder.GetGraph(); | |||
| } | |||
| ComputeGraphPtr BuildGraph4() { | |||
| auto builder = ut::GraphBuilder("g4"); | |||
| ComputeGraphPtr BuildGraph5() { | |||
| auto builder = ut::GraphBuilder("g5"); | |||
| auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3}); | |||
| auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10}); | |||
| auto add = builder.AddNode("add", ADD, 2, 1); | |||
| @@ -85,6 +85,41 @@ ComputeGraphPtr BuildGraph4() { | |||
| return builder.GetGraph(); | |||
| } | |||
| /* | |||
| * MapIndex Data1 subgraph1 subgraph2 | |||
| * \ / | |||
| * Case ===> Data2 Data3 | |||
| * | | |||
| * Netoutput | |||
| */ | |||
| // Builds the parent graph of the multi-batch Case scenario: mapindex1 -> case1:0, | |||
| // data1 -> case1:1, case1 -> netoutput1. data1 carries the user-defined dtype "DT_FLOAT16" | |||
| // and names case1 as its multi-batch node. | |||
| ComputeGraphPtr BuildGraph4() { | |||
|   ut::GraphBuilder builder("mbatch_Case"); | |||
|   // Data node tagged with the attributes that drive the dtype update. | |||
|   auto data_node = builder.AddNode("data1", DATA, 1, 1); | |||
|   auto data_op_desc = data_node->GetOpDesc(); | |||
|   AttrUtils::SetStr(data_op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, "DT_FLOAT16"); | |||
|   AttrUtils::SetStr(data_op_desc, "mbatch-switch-name", "case1"); | |||
|   AttrUtils::SetInt(data_op_desc, ATTR_NAME_INDEX, 0); | |||
|   // Remaining nodes of the parent graph. | |||
|   auto map_index_node = builder.AddNode("mapindex1", "MapIndex", 0, 1); | |||
|   auto case_node = builder.AddNode("case1", CASE, 2, 1); | |||
|   auto net_output_node = builder.AddNode("netoutput1", NETOUTPUT, 1, 0); | |||
|   // Wiring: data1 feeds input 1 of the Case node. | |||
|   builder.AddDataEdge(map_index_node, 0, case_node, 0); | |||
|   builder.AddDataEdge(data_node, 0, case_node, 1); | |||
|   builder.AddDataEdge(case_node, 0, net_output_node, 0); | |||
|   return builder.GetGraph(); | |||
| } | |||
| // Builds a one-node subgraph named graph_name containing a Data node "<graph_name>_data1" | |||
| // whose ATTR_NAME_PARENT_NODE_INDEX is 1 (the Case input that data1 feeds in BuildGraph4). | |||
| // graph_name is only read, so it is taken by const reference to avoid a copy. | |||
| ComputeGraphPtr BuildGraph4_Subgraph(const string &graph_name) { | |||
|   auto builder = ut::GraphBuilder(graph_name); | |||
|   auto data1 = builder.AddNode(graph_name + "_data1", DATA, 1, 1); | |||
|   auto data_desc = data1->GetOpDesc(); | |||
|   AttrUtils::SetInt(data_desc, ATTR_NAME_PARENT_NODE_INDEX, 1); | |||
|   return builder.GetGraph(); | |||
| } | |||
| TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) { | |||
| ge::GraphPrepare graph_prepare; | |||
| graph_prepare.compute_graph_ = BuildGraph1(); | |||
| @@ -135,7 +170,7 @@ TEST_F(UtestGraphPreproces, test_update_input_output1) { | |||
| TEST_F(UtestGraphPreproces, check_ref_op_data_succ) { | |||
| GraphPrepare graph_preparer; | |||
| ComputeGraphPtr graph_test = BuildGraph4(); | |||
| ComputeGraphPtr graph_test = BuildGraph5(); | |||
| NodePtr add_node = nullptr; | |||
| for (auto &node : graph_test->GetAllNodes()) { | |||
| if (node->GetName() == "add") { | |||
| @@ -149,4 +184,43 @@ TEST_F(UtestGraphPreproces, check_ref_op_data_succ) { | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| // Verifies that a user-defined dtype on a Data node feeding a multi-batch Case node is applied | |||
| // to the Case input and to the matching Data node inside every subgraph of the Case. | |||
| TEST_F(UtestGraphPreproces, test_update_dtype_mbatch_case) { | |||
|   ge::GraphPrepare graph_prepare; | |||
|   graph_prepare.compute_graph_ = BuildGraph4(); | |||
|   auto parent_graph = graph_prepare.compute_graph_; | |||
|   auto subgraph1 = BuildGraph4_Subgraph("subgraph1"); | |||
|   auto subgraph2 = BuildGraph4_Subgraph("subgraph2"); | |||
|   auto case_node = parent_graph->FindNode("case1"); | |||
|   // Fatal assertion: case_node is dereferenced right below. | |||
|   ASSERT_NE(case_node, nullptr); | |||
|   case_node->GetOpDesc()->AddSubgraphName("subgraph1"); | |||
|   case_node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph1"); | |||
|   subgraph1->SetParentNode(case_node); | |||
|   subgraph1->SetParentGraph(parent_graph); | |||
|   EXPECT_EQ(parent_graph->AddSubgraph("subgraph1", subgraph1), GRAPH_SUCCESS); | |||
|   case_node->GetOpDesc()->AddSubgraphName("subgraph2"); | |||
|   case_node->GetOpDesc()->SetSubgraphInstanceName(1, "subgraph2"); | |||
|   subgraph2->SetParentNode(case_node); | |||
|   subgraph2->SetParentGraph(parent_graph); | |||
|   EXPECT_EQ(parent_graph->AddSubgraph("subgraph2", subgraph2), GRAPH_SUCCESS); | |||
|   Status ret = graph_prepare.UpdateInputOutputByOptions(); | |||
|   EXPECT_EQ(ret, SUCCESS); | |||
|   // The Case input fed by data1 (input 1) must carry the requested dtype. | |||
|   auto case_desc = case_node->GetOpDesc(); | |||
|   auto case_input = case_desc->MutableInputDesc(1); | |||
|   ASSERT_NE(case_input, nullptr); | |||
|   EXPECT_EQ(case_input->GetDataType(), DT_FLOAT16); | |||
|   // The Data node of the first subgraph must be updated on both its input and its output. | |||
|   auto sub1_data1 = subgraph1->FindNode("subgraph1_data1"); | |||
|   ASSERT_NE(sub1_data1, nullptr); | |||
|   auto data1_desc = sub1_data1->GetOpDesc(); | |||
|   auto data1_input = data1_desc->MutableInputDesc(0); | |||
|   ASSERT_NE(data1_input, nullptr); | |||
|   EXPECT_EQ(data1_input->GetDataType(), DT_FLOAT16); | |||
|   auto data1_output = data1_desc->MutableOutputDesc(0); | |||
|   ASSERT_NE(data1_output, nullptr); | |||
|   EXPECT_EQ(data1_output->GetDataType(), DT_FLOAT16); | |||
|   // The second subgraph is built the same way, so its Data node must be updated too. | |||
|   auto sub2_data1 = subgraph2->FindNode("subgraph2_data1"); | |||
|   ASSERT_NE(sub2_data1, nullptr); | |||
|   auto data2_desc = sub2_data1->GetOpDesc(); | |||
|   auto data2_input = data2_desc->MutableInputDesc(0); | |||
|   ASSERT_NE(data2_input, nullptr); | |||
|   EXPECT_EQ(data2_input->GetDataType(), DT_FLOAT16); | |||
|   auto data2_output = data2_desc->MutableOutputDesc(0); | |||
|   ASSERT_NE(data2_output, nullptr); | |||
|   EXPECT_EQ(data2_output->GetDataType(), DT_FLOAT16); | |||
| } | |||
| } | |||
| @@ -61,9 +61,18 @@ enum FWKTaskExtInfoType { | |||
| FWK_ADPT_EXT_OP_NAME, | |||
| FWK_ADPT_EXT_SESSION_INFO, | |||
| FWK_ADPT_EXT_BITMAP, | |||
| FWK_ADPT_EXT_TOPIC_TYPE, | |||
| FWK_ADPT_EXT_INVALID | |||
| }; | |||
| // Topic type values. The numeric values are spelled out because they are exchanged as raw | |||
| // integers; FWK_ADPT_TOPIC_INVALID is the end marker. | |||
| enum FWKExtTopicType { | |||
|   FWK_ADPT_TOPIC_DEVICE_ONLY = 0, | |||
|   FWK_ADPT_TOPIC_DEVICE_FIRST = 1, | |||
|   FWK_ADPT_TOPIC_HOST_ONLY = 2, | |||
|   FWK_ADPT_TOPIC_HOST_FIRST = 3, | |||
|   FWK_ADPT_TOPIC_INVALID = 4 | |||
| }; | |||
| enum FWKExtUpdateAddrType { | |||
| FWK_ADPT_UPDATE_NULL = 0, | |||
| FWK_ADPT_UPDATE_INPUT, | |||
| @@ -191,6 +191,11 @@ typedef void (*rtCallback_t)(void *fnData); | |||
| #define RT_FUSION_KERNEL_DUMPFLAG (0x04) | |||
| #define RT_KERNEL_CUSTOM_AICPU (0x08) | |||
| // STARS topic scheduler sqe : topic_type | |||
| #define RT_KERNEL_DEVICE_FIRST (0X10) | |||
| #define RT_KERNEL_HOST_ONLY (0X20) | |||
| #define RT_KERNEL_HOST_FIRST (0X30) | |||
| /** | |||
| * @ingroup rt_kernel | |||
| * @brief kernel mode | |||