@@ -922,7 +922,7 @@ Status HybridModelBuilder::InitWeights() {
}
Status HybridModelBuilder::LoadTasks() {
GE_CHK_STATUS_RET(CheckAicpuOp(), "Check Aicpu op failed.");
GE_CHK_STATUS_RET(CheckAicpuOpList (), "Check Aicpu op failed.");
for (auto &it : hybrid_model_.node_items_) {
auto &node_item = it.second;
auto &node_ptr = node_item->node;
@@ -1560,21 +1560,20 @@ Status HybridModelBuilder::BuildInputMapping(GraphItem &graph_item,
return SUCCESS;
}
Status HybridModelBuilder::CheckAicpuOp() {
Status HybridModelBuilder::CheckAicpuOpList () {
std::vector<std::string> aicpu_optype_list;
std::vector<std::string> aicpu_tf_optype_list;
std::set<std::string> aicpu_optype_set;
std::set<std::string> aicpu_tf_optype_set;
const auto &root_graph = ge_root_model_->GetRootGraph();
for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) {
auto &name = it.first;
auto &ge_model = it.second;
GE_CHECK_NOTNULL(ge_model);
if (ge::AttrUtils::GetListStr(ge_model, "needCheckCpu", aicpu_optype_list)) {
if (ge::AttrUtils::GetListStr(* ge_model, "needCheckCpu", aicpu_optype_list)) {
aicpu_optype_set.insert(aicpu_optype_list.begin(), aicpu_optype_list.end());
}
if (ge::AttrUtils::GetListStr(ge_model, "needCheckTf", aicpu_tf_optype_list)) {
if (ge::AttrUtils::GetListStr(* ge_model, "needCheckTf", aicpu_tf_optype_list)) {
aicpu_tf_optype_set.insert(aicpu_tf_optype_list.begin(), aicpu_tf_optype_list.end());
}
}
@@ -1582,6 +1581,7 @@ Status HybridModelBuilder::CheckAicpuOp() {
aicpu_optype_list.assign(aicpu_optype_set.begin(), aicpu_optype_set.end());
aicpu_tf_optype_list.assign(aicpu_tf_optype_set.begin(), aicpu_tf_optype_set.end());
GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchKernelCheckAicpuOp(aicpu_optype_list, aicpu_tf_optype_list), "Launch check aicpu op type failed.");
return SUCCESS;
}
} // namespace hybrid
} // namespace ge