From 33ef17e80041627a82f9dab02eee97d58a6ca178 Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Mon, 13 Jun 2022 10:50:21 +0800 Subject: [PATCH] add ut --- .../caffe_parser_unittest.cc | 51 +++++++++++++++++++ .../common/acl_graph_parser_unittest.cc | 1 + 2 files changed, 52 insertions(+) diff --git a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc index 846a209..cbfe7c5 100755 --- a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc +++ b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc @@ -1194,4 +1194,55 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ReorderInput_test) modelParser.ReorderInput(net); } +TEST_F(UtestCaffeParser, CaffeOpParser_ParseParms_test) +{ + CaffeOpParser parser; + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string caffe_proto = case_dir + "/../../../../../metadef/proto/caffe/"; + google::protobuf::compiler::DiskSourceTree sourceTree; + sourceTree.MapPath("project_root", caffe_proto); + google::protobuf::compiler::Importer importer(&sourceTree, nullptr); + importer.Import("project_root/caffe.proto"); + auto descriptor = importer.pool()->FindMessageTypeByName("domi.caffe.LayerParameter"); + ge::OpDescPtr op_desc_src = std::make_shared("Abs", "AbsVal"); + google::protobuf::DynamicMessageFactory factory; + const google::protobuf::Message *proto = factory.GetPrototype(descriptor); + const google::protobuf::Message *message = proto->New(); + ge::Operator op_src = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); + Status ret = parser.ParseParams(message, op_src); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestCaffeParser, CaffeModelParser_Constructor_and_delete) +{ + CaffeModelParser modelParser; + domi::caffe::NetParameter net; + net.add_input("111"); + bool input_data_flag = true; + net.add_input_shape(); + Status ret = modelParser.ParseInput(net, input_data_flag); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestCaffeParser, ParseFromMemory_success_graph) +{ + std::string caseDir = __FILE__; + std::size_t idx = caseDir.find_last_of("/"); + caseDir = caseDir.substr(0, idx); + std::string modelFile = caseDir + "/caffe_model/caffe_add.pbtxt"; + std::string weight_file = caseDir + "/caffe_model/caffe_add.caffemodel"; + + const char* tmp_tf_pb_model = modelFile.c_str(); + const char* tmp_tf_weight_model = weight_file.c_str(); + ge::Graph graph; + + Status ret = ge::aclgrphParseCaffe(modelFile.c_str(), weight_file.c_str(), graph); + CaffeModelParser modelParser; + MemBuffer* memBuffer1 = ParerUTestsUtils::MemBufferFromFile(tmp_tf_pb_model); + ret = modelParser.ParseFromMemory((char*)memBuffer1->data, memBuffer1->size, graph); + EXPECT_EQ(ret, SUCCESS); + delete memBuffer1; +} + } // namespace ge diff --git a/tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc b/tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc index ec3ac01..e452661 100755 --- a/tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc +++ b/tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc @@ -344,6 +344,7 @@ TEST_F(UtestAclGraphParser, test_operatoreq) } TEST_F(UtestAclGraphParser, test_pre_checker) { + TBEPluginLoader tbe_plugin; PreChecker::Instance().fmk_op_types_ = nullptr; const char* str = "iiii"; PreChecker::OpId id = str;