Browse Source

parser ut

pull/434/head
jwx930962 4 years ago
parent
commit
6a0d524fe1
1 changed files with 72 additions and 6 deletions
  1. +72
    -6
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 72
- 6
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -76,6 +76,7 @@
#include "parser/tensorflow/graph_optimizer.h"
#include "metadef/inc/register/scope/scope_pass_registry_impl.h"
#include "register/scope/scope_fusion_pass_register.h"
#include "common/op_map.h"
#undef protected
#undef private

@@ -90,6 +91,15 @@ using namespace google::protobuf;
static const string GRAPH_DEFAULT_NAME = "default";

namespace ge {
struct DelTransposeInfo {
domi::tensorflow::NodeDef *node_def; // transpose
domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next]
int inputIdx;
};

Status GetTransposeInfo(GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo,
std::map<std::string, DelTransposeInfo> &transposeInfo);

class UtestTensorflowParser : public testing::Test {
protected:
void SetUp() {
@@ -161,12 +171,6 @@ void UtestTensorflowParser::RegisterCustomOp() {
domi::OpRegistry::Instance()->registrationDatas.clear();
}

struct DelTransposeInfo {
domi::tensorflow::NodeDef *node_def; // transpose
domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next]
int inputIdx;
};

namespace {
NodeDef* AddNode(GraphDef& graph, string type, string name) {
NodeDef* nodeDef = graph.add_node();
@@ -4492,6 +4496,8 @@ TEST_F(UtestTensorflowParser, tensorflow_EraseNormalOpOutputIfChild)

ret = modelParser.EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context);
EXPECT_EQ(ret, SUCCESS);

delete node;
}

TEST_F(UtestTensorflowParser, tensorflow_UpdateNormalOpContext)
@@ -4552,4 +4558,64 @@ TEST_F(UtestTensorflowParser, tensorflow_OptimizeTranspose)
delete info.nextNodeDef;
}

TEST_F(UtestTensorflowParser, tensorflow_SoftmaxAddAttr)
{
TensorFlowModelParser modelParser;
domi::tensorflow::GraphDef graph_def;
graph_def.add_node();
modelParser.SoftmaxAddAttr(&graph_def);
}

TEST_F(UtestTensorflowParser, tensorflow_InferInputFormats)
{
domiTensorFormat_t ret;
TensorFlowModelParser modelParser;

GetParserContext().format = DOMI_TENSOR_RESERVED;
NodeDef *node = MallocNodeDef("node", "DATA");
modelParser.nodedef_map_["node"] = node;
tensorflow_op_map["DATA"] = "node";
ret = modelParser.InferInputFormats();
EXPECT_EQ(ret, domi::DOMI_TENSOR_NHWC);
delete node;
NodeDef* node1 = nullptr;
modelParser.nodedef_map_["node"] = node1;

ret = modelParser.InferInputFormats();
EXPECT_EQ(ret, domi::DOMI_TENSOR_RESERVED);
}

TEST_F(UtestTensorflowParser, tensorflow_GetTransposeInfo)
{
Status ret;
DelTransposeInfo info;
tensorflow::GraphDef *graph = new tensorflow::GraphDef();
std::map<std::string, std::string> softmaxInfo = {{"ge", "ge"}};

info.node_def = new NodeDef();
info.nextNodeDef = new NodeDef();
info.node_def->add_input("ge");
info.nextNodeDef->add_input("ge");
info.inputIdx = 0;

NodeDef *node = graph->add_node();
node->set_op("Transpose");

std::map<std::string, DelTransposeInfo> transposeInfo = {{"Softmax", info}};
ret = ge::GetTransposeInfo(graph, softmaxInfo, transposeInfo);
EXPECT_EQ(ret, SUCCESS);

node->set_op("Softmax");
node->set_name("Softmax");
node->add_input("Softmax");
ret = ge::GetTransposeInfo(graph, softmaxInfo, transposeInfo);
EXPECT_EQ(ret, SUCCESS);

delete info.node_def;
delete info.nextNodeDef;
delete graph;
}

} // namespace ge

Loading…
Cancel
Save