|
|
|
@@ -306,6 +306,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts( |
|
|
|
FuncGraphPtr paserTfFuction() { return nullptr; } |
|
|
|
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, |
|
|
|
const QuantType &quantType) { |
|
|
|
NoSupportOp::GetInstance()->SetFmkType("TF"); |
|
|
|
auto status = ValidateFileStr(modelFile, ".pb"); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; |
|
|
|
@@ -321,7 +322,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin |
|
|
|
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
anf_root_graph_ = std::make_shared<FuncGraph>(); |
|
|
|
@@ -346,13 +347,13 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin |
|
|
|
for (int i = 0; i < tf_root_graph_->node_size(); i++) { |
|
|
|
auto &node_def = tf_root_graph_->node(i); |
|
|
|
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
|
if (status != RET_OK) { |
|
|
|
success_flag = false; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!success_flag) { |
|
|
|
MS_LOG(ERROR) << "Convert ops failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
status = ConvertRootGraphOutputs(); |
|
|
|
@@ -376,6 +377,7 @@ STATUS TFModelParser::ConvertSubgraph() { |
|
|
|
auto subgraph_size = graph_def_liarary.function_size(); |
|
|
|
std::map<CNodePtr, FuncGraphPtr> while_cond_map; |
|
|
|
std::map<CNodePtr, FuncGraphPtr> while_body_map; |
|
|
|
bool success_flag = true; |
|
|
|
for (int i = 0; i < subgraph_size; i++) { |
|
|
|
auto &tf_sub_fuction = graph_def_liarary.function(i); |
|
|
|
auto &tf_sub_signature = tf_sub_fuction.signature(); |
|
|
|
@@ -421,12 +423,16 @@ STATUS TFModelParser::ConvertSubgraph() { |
|
|
|
for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) { |
|
|
|
auto &node_def = tf_sub_fuction.node_def(j); |
|
|
|
status = ConvertOps(node_def, tf_sub_node_map, sub_func_graph, &anf_sub_node_map); |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Convert subgraph ops failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
|
return RET_ERROR; |
|
|
|
success_flag = false; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!success_flag) { |
|
|
|
MS_LOG(ERROR) << "Convert subgraph is failed."; |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
// convert subgraph outputs |
|
|
|
std::vector<AnfNodePtr> sub_output_nodes; |
|
|
|
@@ -483,6 +489,10 @@ STATUS TFModelParser::ConvertSubgraph() { |
|
|
|
|
|
|
|
MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; |
|
|
|
} |
|
|
|
if (!success_flag) { |
|
|
|
MS_LOG(ERROR) << "Convert subgraph is failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto status = WhileNodePostProcess(while_cond_map, while_body_map); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "while node post process failed"; |
|
|
|
@@ -593,7 +603,6 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, |
|
|
|
std::unordered_map<std::string, AnfNodePtr> *anf_node_map) { |
|
|
|
MS_ASSERT(node_def != nullptr); |
|
|
|
MS_ASSERT(func_graph_ptr != nullptr); |
|
|
|
NoSupportOp::GetInstance()->SetFmkType("TF"); |
|
|
|
STATUS status = RET_OK; |
|
|
|
const auto &op_type = node_def.op(); |
|
|
|
if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") { |
|
|
|
@@ -645,8 +654,6 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, |
|
|
|
status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return status; |
|
|
|
} |
|
|
|
|