| @@ -217,7 +217,7 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { | |||
| std::string unsupported_reason; | |||
| // It will be replaced by engine' checksupport | |||
| uint64_t start_time = GetCurrentTimestamp(); | |||
| if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | |||
| if (kernel_info_store->second->CheckSupported(node_ptr, unsupported_reason)) { | |||
| checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; | |||
| op_desc->SetOpEngineName(it.engine); | |||
| op_desc->SetOpKernelLibName(kernel_name); | |||
| @@ -66,7 +66,8 @@ bool ContainsDynamicInpus(const ge::OpDesc &op_desc) { | |||
| } // namespace | |||
| namespace ge { | |||
| static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engine_type) { | |||
| static Status CheckEngineTypeSupport(const Nodeptr &node, OpEngineType engine_type) { | |||
| const OpDescPtr &op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); | |||
| if (engine_type == ENGINE_SYS) { | |||
| GELOGI("CheckEngineType: use default engine."); | |||
| @@ -123,7 +124,7 @@ static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engi | |||
| auto kernel_info_store = kernel_map.find(kernel_name); | |||
| if (kernel_info_store != kernel_map.end()) { | |||
| std::string unsupported_reason; | |||
| if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | |||
| if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) { | |||
| op_desc->SetOpEngineName(op_engine_name); | |||
| op_desc->SetOpKernelLibName(kernel_name); | |||
| GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), | |||
| @@ -692,22 +693,26 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||
| OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc); | |||
| GE_CHECK_NOTNULL(op_desc_tmp); | |||
| // 1. check engine type when compile online | |||
| // 1. Create ComputeGraph. | |||
| string name = ge::CurrentTimeInStr() + "_" + model_file_name; | |||
| Graph graph; | |||
| if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "make graph fail."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| // 2. check engine type when compile online | |||
| if (model_file_name == kFileNameSuffix) { | |||
| Status ret = CheckEngineTypeSupport(op_desc, engine_type); | |||
| auto comp_graph = GraphUtils::GetComputeGraph(graph); | |||
| GE_CHECK_NOTNULL(comp_graph); | |||
| auto node = comp_graph->FindNode(op_desc->GetName()); | |||
| Status ret = CheckEngineTypeSupport(node, engine_type); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "check engine type failed."); | |||
| return ret; | |||
| } | |||
| } | |||
| // 2. Create ComputeGraph. | |||
| string name = ge::CurrentTimeInStr() + "_" + model_file_name; | |||
| Graph graph; | |||
| if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "make graph fail."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGI("ATC parser success in single op build."); | |||
| GeRootModelPtr ge_root_model = nullptr; | |||
| @@ -167,7 +167,7 @@ bool CastTranslatePass::IsOpSupportedOptimize(NodePtr &cast_node, NodePtr &trans | |||
| trans_op_outdesc->SetDataType(cast_out_datatype); | |||
| } | |||
| if (!TranslateCheckAccuracySupported(trans_op_desc)) { | |||
| if (!TranslateCheckAccuracySupported(trans_node)) { | |||
| if (is_src_cast) { | |||
| trans_op_desc->MutableInputDesc(0)->SetDataType(trans_in_datatype); | |||
| } else { | |||
| @@ -271,7 +271,8 @@ Status CastTranslatePass::FuseDstNTranslates(NodePtr &node) { | |||
| return SUCCESS; | |||
| } | |||
| bool CastTranslatePass::TranslateCheckAccuracySupported(const OpDescPtr &op_desc) { | |||
| bool CastTranslatePass::TranslateCheckAccuracySupported(NodePtr &node) { | |||
| const OpDescPtr &op_desc = node->GetOpDesc(); | |||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||
| GELOGW("GE is not initialized or is finalized."); | |||
| @@ -293,7 +294,7 @@ bool CastTranslatePass::TranslateCheckAccuracySupported(const OpDescPtr &op_desc | |||
| auto kernel_info_store = kernel_map.find(kernel_name); | |||
| if (kernel_info_store != kernel_map.end()) { | |||
| if (kernel_info_store->second != nullptr && | |||
| kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) { | |||
| kernel_info_store->second->CheckAccuracySupported(node, unsupported_reason)) { | |||
| return true; | |||
| } | |||
| } | |||
| @@ -35,7 +35,7 @@ class CastTranslatePass : public BaseNodePass { | |||
| bool IsOpSupportedOptimize(NodePtr &cast_node, NodePtr &trans_node, bool &is_src_cast); | |||
| bool CheckOpSupportOptimize(NodePtr &node, bool &is_src_cast); | |||
| Status FuseDstNTranslates(NodePtr &node); | |||
| bool TranslateCheckAccuracySupported(const OpDescPtr &op_desc); | |||
| bool TranslateCheckAccuracySupported(NodePtr &node); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_CAST_TRANSLATE_PASS_H_ | |||
| @@ -110,7 +110,7 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||
| return ge::GE_GRAPH_PARAM_NULLPTR; | |||
| } | |||
| // begin accuracy supported check | |||
| if (!CheckAccuracySupport(kernel_info, instance, op_desc)) { | |||
| if (!CheckAccuracySupport(kernel_info, instance, node)) { | |||
| // if check accuracy support failed , try to go to other engine. | |||
| GELOGD("Check Accuracy Supported return not support, node name is %s. Try to go to other engine.", | |||
| op_desc->GetName().c_str()); | |||
| @@ -123,7 +123,7 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||
| continue; | |||
| } | |||
| OpsKernelInfoStorePtr tmp_kernel_info = it->second; | |||
| if (CheckAccuracySupport(tmp_kernel_info, instance, op_desc)) { | |||
| if (CheckAccuracySupport(tmp_kernel_info, instance, node)) { | |||
| kernel_lib_name = tmp_kernel_name; | |||
| GELOGD("Find kernel lib %s support node:%s, type:%s , get kernel lib success.", tmp_kernel_name.c_str(), | |||
| node->GetName().c_str(), op_desc->GetType().c_str()); | |||
| @@ -138,14 +138,9 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||
| } | |||
| bool CompileNodesPass::CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, | |||
| const std::shared_ptr<GELib> instance, OpDescPtr &op_desc) { | |||
| auto ge_desc = MakeShared<ge::OpDescPtr>(op_desc); | |||
| if (ge_desc == nullptr) { | |||
| GELOGE(GE_GRAPH_MEMORY_ALLOC_FAILED, "Fail to malloc op desc."); | |||
| return false; | |||
| } | |||
| const std::shared_ptr<GELib> instance, const NodePtr &node) { | |||
| string reason; | |||
| if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) { | |||
| if (!(kernel_info->CheckAccuracySupported(node, reason, true))) { | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -39,7 +39,7 @@ class CompileNodesPass : public GraphPass { | |||
| private: | |||
| graphStatus GetSupportedKernel(const NodePtr &node, const std::shared_ptr<GELib> instance, string &kernel_lib_name); | |||
| bool CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, const std::shared_ptr<GELib> instance, | |||
| OpDescPtr &op_desc); | |||
| const NodePtr &node); | |||
| graphStatus CompileNodes(const std::shared_ptr<GELib> instance, | |||
| std::unordered_map<string, vector<NodePtr>> &kernel_to_compile_nodes); | |||
| }; | |||
| @@ -86,7 +86,7 @@ Status TransposeTransDataPass::Run(NodePtr &node) { | |||
| if (CheckOneInAndOneOutDataAnchor(out_node)) { | |||
| return FAILED; | |||
| } | |||
| if (!FusionIfNeed(op_desc, out_op_desc)) { | |||
| if (!FusionIfNeed(node, out_op_desc)) { | |||
| continue; | |||
| } | |||
| CopyInputEdges(node, out_node); | |||
| @@ -152,7 +152,8 @@ Status TransposeTransDataPass::RemoveTranspose(NodePtr &node) { | |||
| return SUCCESS; | |||
| } | |||
| bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transdata_op_desc) { | |||
| bool TransposeTransDataPass::FusionIfNeed(NodePtr &node, OpDescPtr &transdata_op_desc) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| GE_CHECK_NOTNULL(transdata_op_desc); | |||
| auto out_input_desc = transdata_op_desc->MutableInputDesc(0); | |||
| @@ -187,7 +188,7 @@ bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transda | |||
| out_input_desc->SetFormat(src_format); | |||
| out_input_desc->SetShape(src_shape); | |||
| if (!TransDataCheckAccuracySupported(transdata_op_desc)) { | |||
| if (!TransDataCheckAccuracySupported(node)) { | |||
| out_input_desc->SetFormat(out_input_format); | |||
| out_input_desc->SetShape(out_input_shape); | |||
| return false; | |||
| @@ -224,7 +225,8 @@ void TransposeTransDataPass::CopyInputEdges(NodePtr &origin_node, NodePtr &new_n | |||
| GraphUtils::CopyInCtrlEdges(origin_node, new_node) != GRAPH_SUCCESS, GELOGW("Copy in ctrl edges failed"); return); | |||
| } | |||
| bool TransposeTransDataPass::TransDataCheckAccuracySupported(const OpDescPtr &op_desc) { | |||
| bool TransposeTransDataPass::TransDataCheckAccuracySupported(NodePtr &node) { | |||
| const OpDescPtr &op_desc = node->GetOpDesc(); | |||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||
| GELOGW("GELib not initialized"); | |||
| @@ -244,7 +246,7 @@ bool TransposeTransDataPass::TransDataCheckAccuracySupported(const OpDescPtr &op | |||
| auto &kernel_name = it.opKernelLib; | |||
| auto kernel_info_store = kernel_map.find(kernel_name); | |||
| if (kernel_info_store != kernel_map.end()) { | |||
| if (kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason, true)) { | |||
| if (kernel_info_store->second->CheckAccuracySupported(node, unsupported_reason, true)) { | |||
| return true; | |||
| } | |||
| } | |||
| @@ -26,9 +26,9 @@ class TransposeTransDataPass : public BaseNodePass { | |||
| private: | |||
| Status CheckOneInAndOneOutDataAnchor(NodePtr &node) const; | |||
| Status RemoveTranspose(NodePtr &node); | |||
| bool FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transdata_op_desc); | |||
| bool FusionIfNeed(NodePtr &node, OpDescPtr &transdata_op_desc); | |||
| void CopyInputEdges(NodePtr &origin_node, NodePtr &new_node); | |||
| bool TransDataCheckAccuracySupported(const OpDescPtr &op_desc); | |||
| bool TransDataCheckAccuracySupported(NodePtr &node); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_TRANSPOSE_TRANSDATA_PASS_H_ | |||