Browse Source

parset st

pull/424/head
y00500818 4 years ago
parent
commit
f0d9208600
5 changed files with 152 additions and 0 deletions
  1. +26
    -0
      tests/st/parser_st_utils.cc
  2. +2
    -0
      tests/st/parser_st_utils.h
  3. +36
    -0
      tests/st/testcase/origin_models/caffe_add.caffemodel.txt
  4. +28
    -0
      tests/st/testcase/origin_models/caffe_add.pbtxt
  5. +60
    -0
      tests/st/testcase/test_caffe_parser.cc

+ 26
- 0
tests/st/parser_st_utils.cc View File

@@ -17,6 +17,11 @@
#include "st/parser_st_utils.h"
#include "framework/common/debug/ge_log.h"
#include <limits.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <fstream>


namespace ge {
void ParerSTestsUtils::ClearParserInnerCtx() {
@@ -105,4 +110,25 @@ MemBuffer* ParerSTestsUtils::MemBufferFromFile(const char *path) {
return membuf;
}

bool ParerSTestsUtils::ReadProtoFromText(const char *file, google::protobuf::Message *message) {
std::ifstream fs(file);
if (!fs.is_open()) {
return false;
}
google::protobuf::io::IstreamInputStream input(&fs);
bool ret = google::protobuf::TextFormat::Parse(&input, message);

fs.close();
return ret;
}

void ParerSTestsUtils::WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename) {
size_t size = proto.ByteSizeLong();
char *buf = new char[size];
proto.SerializeToArray(buf, size);
std::ofstream out(filename);
out.write(buf, size);
out.close();
delete[] buf;
}
} // namespace ge

+ 2
- 0
tests/st/parser_st_utils.h View File

@@ -29,6 +29,8 @@ class ParerSTestsUtils {
public:
static void ClearParserInnerCtx();
static MemBuffer* MemBufferFromFile(const char *path);
static bool ReadProtoFromText(const char *file, google::protobuf::Message *message);
static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename);
};
} // namespace ge



+ 36
- 0
tests/st/testcase/origin_models/caffe_add.caffemodel.txt View File

@@ -0,0 +1,36 @@
name: "TestAdd"
input: "data"
layer {
name: "data"
type: "Input"
top: "data"
input_param { shape: { dim: 3} }
}

layer {
name: "const"
type: "Input"
top: "const"
input_param { shape: { dim: 3} }
blobs {
data: 1
data: 2
data: 3
shape {
dim: 3
}
}
}

layer {
name: "reshape"
type: "Reshape"
bottom: "data"
bottom: "const"
top: "reshpae_out"
reshape_param {
shape {
dim: 3
}
}
}

+ 28
- 0
tests/st/testcase/origin_models/caffe_add.pbtxt View File

@@ -0,0 +1,28 @@
name: "TestAdd"
input: "data"
layer {
name: "data"
type: "Input"
top: "data"
input_param { shape: { dim: 3} }
}

layer {
name: "const"
type: "Input"
top: "const"
input_param { shape: { dim: 3} }
}

layer {
name: "reshape"
type: "Reshape"
bottom: "data"
bottom: "const"
top: "reshpae_out"
reshape_param {
shape {
dim: 3
}
}
}

+ 60
- 0
tests/st/testcase/test_caffe_parser.cc View File

@@ -25,6 +25,8 @@
#include "st/parser_st_utils.h"
#include "external/ge/ge_api_types.h"
#include "tests/depends/ops_stub/ops_stub.h"
#include "proto/caffe/caffe.pb.h"
#include "parser/caffe/caffe_parser.h"

namespace ge {
class STestCaffeParser : public testing::Test {
@@ -87,4 +89,62 @@ TEST_F(STestCaffeParser, caffe_parser_user_output_with_default) {
EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out");
}

TEST_F(STestCaffeParser, acal_caffe_parser) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/caffe_add.pbtxt";
std::string weight_file_txt = case_dir + "/origin_models/caffe_add.caffemodel.txt";
std::string weight_file = case_dir + "/origin_models/caffe_add.caffemodel";

domi::caffe::NetParameter proto;
EXPECT_EQ(ParerSTestsUtils::ReadProtoFromText(weight_file_txt.c_str(), &proto), true);
ParerSTestsUtils::WriteProtoToBinaryFile(proto, weight_file.c_str());

ge::GetParserContext().caffe_proto_path = case_dir + "/../../../../metadef/proto/caffe/caffe.proto";

std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, GRAPH_FAILED);
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph);
EXPECT_EQ(ret, GRAPH_FAILED);
}

TEST_F(STestCaffeParser, modelparser_parsefrommemory_success)
{
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");
caseDir = caseDir.substr(0, idx);
std::string modelFile = caseDir + "/origin_models/caffe_add.pbtxt";
const char* tmp_tf_pb_model = modelFile.c_str();
ge::Graph graph;

ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
CaffeModelParser modelParser;
MemBuffer* memBuffer = ParerSTestsUtils::MemBufferFromFile(tmp_tf_pb_model);
auto ret = modelParser.ParseFromMemory((char*)memBuffer->data, memBuffer->size, compute_graph);
free(memBuffer->data);
delete memBuffer;
EXPECT_EQ(ret, GRAPH_FAILED);
}

TEST_F(STestCaffeParser, caffe_parser_to_json) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/caffe_add.pbtxt";
std::map<ge::AscendString, ge::AscendString> parser_params;
CaffeModelParser caffe_parser;

const char *json_file = "tmp.json";
auto ret = caffe_parser.ToJson(model_file.c_str(), json_file);
EXPECT_EQ(ret, SUCCESS);

const char *json_null = nullptr;
ret = caffe_parser.ToJson(model_file.c_str(), json_null);
EXPECT_EQ(ret, FAILED);
const char *model_null = nullptr;
ret = caffe_parser.ToJson(model_null, json_null);
EXPECT_EQ(ret, FAILED);
}

} // namespace ge

Loading…
Cancel
Save