| @@ -2563,6 +2563,7 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m | |||||
| domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | ||||
| GE_CHECK_NOTNULL(output_node_def); | GE_CHECK_NOTNULL(output_node_def); | ||||
| auto inputs = output_node_def->mutable_input(); | auto inputs = output_node_def->mutable_input(); | ||||
| std::vector<std::string> added_inputs; | |||||
| for (auto &input : *inputs) { | for (auto &input : *inputs) { | ||||
| string node_name; | string node_name; | ||||
| bool is_control = false; | bool is_control = false; | ||||
| @@ -2596,12 +2597,15 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m | |||||
| } | } | ||||
| } | } | ||||
| if (!is_exist_input) { | if (!is_exist_input) { | ||||
| output_node_def->add_input("^" + item); | |||||
| GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), item.c_str()); | |||||
| added_inputs.push_back("^" + item); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| for (std::string added_input : added_inputs) { | |||||
| GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), added_input.c_str()); | |||||
| output_node_def->add_input(added_input); | |||||
| } | |||||
| } | } | ||||
| // Clear the input of snapshot and become an isolated node | // Clear the input of snapshot and become an isolated node | ||||
| curr_mode_def->clear_input(); | curr_mode_def->clear_input(); | ||||
| @@ -98,7 +98,7 @@ void ErrorManager::SetStage(const std::string &first_stage, const std::string &s | |||||
| } | } | ||||
| struct error_message::Context &ErrorManager::GetErrorManagerContext() { | struct error_message::Context &ErrorManager::GetErrorManagerContext() { | ||||
| struct error_message::Context error_context; | |||||
| static struct error_message::Context error_context; | |||||
| return error_context; | return error_context; | ||||
| } | } | ||||
| @@ -2387,5 +2387,23 @@ TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) | |||||
| Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | ||||
| EXPECT_EQ(ret, ge::PARAM_INVALID); | EXPECT_EQ(ret, ge::PARAM_INVALID); | ||||
| } | } | ||||
| TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | |||||
| std::string caseDir = __FILE__; | |||||
| std::size_t idx = caseDir.find_last_of("/"); | |||||
| caseDir = caseDir.substr(0, idx); | |||||
| const std::string root_proto = caseDir + "/origin_models/test_snapshot.pb"; | |||||
| domi::tensorflow::GraphDef graphDef; | |||||
| bool protoRet = | |||||
| parser::ReadProtoFromBinaryFile(root_proto.c_str(), &graphDef); | |||||
| ASSERT_EQ(protoRet, true); | |||||
| TensorFlowModelParser tensorflow_parser; | |||||
| ge::ComputeGraphPtr root_graph = | |||||
| ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph"); | |||||
| Status ret = tensorflow_parser.ParseProto( | |||||
| reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); | |||||
| EXPECT_EQ(FAILED, ret); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -188,4 +188,33 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_with_external_graph) { | |||||
| ret = TensorFlowModelParser::AddExternalGraph(root_graph); | ret = TensorFlowModelParser::AddExternalGraph(root_graph); | ||||
| EXPECT_EQ(ret, INTERNAL_ERROR); | EXPECT_EQ(ret, INTERNAL_ERROR); | ||||
| } | } | ||||
| TEST_F(UtestTensorflowParser, optimize_snapshot) { | |||||
| domi::tensorflow::GraphDef graph_def; | |||||
| auto mul_node = graph_def.add_node(); | |||||
| mul_node->set_name("optimizer/Mul"); | |||||
| mul_node->set_op("Mul"); | |||||
| mul_node->add_input("Snapshot:0"); | |||||
| auto snapshot_node = graph_def.add_node(); | |||||
| snapshot_node->set_name("Snapshot"); | |||||
| snapshot_node->set_op("Snapshot"); | |||||
| snapshot_node->add_input("loss_scale/read:0"); | |||||
| snapshot_node->add_input("^ShuffleNet/AssignMovingAvg"); | |||||
| auto identity_node = graph_def.add_node(); | |||||
| identity_node->set_name("loss_scale/read"); | |||||
| identity_node->set_op("Identity"); | |||||
| identity_node->add_input("loss_scale/ref:0"); | |||||
| auto assign_node = graph_def.add_node(); | |||||
| assign_node->set_name("ShuffleNet/AssignMovingAvg"); | |||||
| assign_node->set_op("AssignSub"); | |||||
| assign_node->add_input("ShuffleNet/moving_mean:0"); | |||||
| Status ret = TensorFlowModelParser().GraphDefOptimize(&graph_def); | |||||
| EXPECT_EQ(ret, ge::SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||