Browse Source

st parser

pull/422/head
jwx930962 4 years ago
parent
commit
16923a1ce4
1 changed files with 119 additions and 2 deletions
  1. +119
    -2
      tests/st/testcase/test_tensorflow_parser.cc

+ 119
- 2
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -65,6 +65,8 @@
#include "parser/common/model_saver.h"
#include "framework/omg/parser/parser_api.h"
#include "parser/common/parser_fp16_t.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/prototype_pass_manager.h"
#undef protected
#undef private

@@ -3101,6 +3103,46 @@ TEST_F(STestTensorflowParser, tensorflow_OptimizeConstNodes4CustomOp_test)
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(STestTensorflowParser, OptimizeConstNodes4CustomOp_success)
{
GraphDef graph;
auto bn = AddNode(graph, "FusedBatchNormV3", "FusedBatchNormV3_0");
auto bn_grad = AddNode(graph, "FusedBatchNormGradV3", "FusedBatchNormGradV3_0");

AddInput(bn, bn_grad, 0);
AddInput(bn, bn_grad, 1);
AddInput(bn, bn_grad, 2);
AddInput(bn, bn_grad, 3);
AddInput(bn, bn_grad, 5);
AddInput(bn, bn_grad, 5);

GraphDef* graphDef = &graph;
int before_bn_grad_input_size = bn_grad->input_size();
ASSERT_EQ(before_bn_grad_input_size, 6);

ModelParserFactory* factory = ModelParserFactory::Instance();
shared_ptr<domi::ModelParser> model_parser= factory->CreateModelParser(domi::TENSORFLOW);
ge::TensorFlowModelParser tensorflow_parser;

Status ret = tensorflow_parser.OptimizeConstNodes4CustomOp(graphDef);
int after_bn_grad_input_size = bn_grad->input_size();
ASSERT_EQ(after_bn_grad_input_size, 6);
ASSERT_EQ(ret, domi::SUCCESS);

REGISTER_CUSTOM_OP("BatchNormGrad")
.FrameworkType(domi::TENSORFLOW)
.OriginOpType({"FusedBatchNormGradV3", "FusedBatchNormGradV2", "FusedBatchNormGrad"})
.ParseParamsFn(AutoMappingFn)
.DelInputWithOriginalType(5, "FusedBatchNormGradV3")
.ImplyType(ImplyType::TVM);
register_tbe_op();

ret = tensorflow_parser.OptimizeConstNodes4CustomOp(graphDef);
after_bn_grad_input_size = bn_grad->input_size();
ASSERT_EQ(after_bn_grad_input_size, 6);
ASSERT_EQ(ret, domi::SUCCESS);
}

TEST_F(STestTensorflowParser, tensorflow_ParseOpParams_test)
{
TensorFlowModelParser model_parser;
@@ -3483,11 +3525,16 @@ TEST_F(STestTensorflowParser, tensorflow_tbe_tfplugin_loader_test)
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");
caseDir = caseDir.substr(0, idx);
std::string proto_file = caseDir + "/origin_models/caffe.proto";
std::string proto_file = caseDir + "/origin_models/";
std::string path = proto_file;
std::string caffe_parser_path = path;
pluginLoad.FindParserSo(path, fileList, caffe_parser_path);

setenv("ASCEND_OPP_PATH", "aaa", 1);
std::string customop_path = "";
pluginLoad.GetCustomOpPath(customop_path);
ASSERT_EQ(customop_path, "aaa/framework/custom/:aaa/framework/built-in/tensorflow");

Status ret = pluginLoad.Finalize();
EXPECT_EQ(ret, SUCCESS);
}
@@ -3551,9 +3598,14 @@ TEST_F(STestTensorflowParser, tensorflow_ReadBytesFromBinaryFile_test)
ret = parser::ReadBytesFromBinaryFile(file_name, &buffer, length);
EXPECT_EQ(ret, true);

const char *path = nullptr;
char path[4096 + 1] = { 0 };
memset(path, 'a', 4096);
std::string realPath = parser::RealPath(path);
EXPECT_EQ(realPath, "");

const char *real_path = nullptr;
realPath = parser::RealPath(real_path);
EXPECT_EQ(realPath, "");
}

TEST_F(STestTensorflowParser, tensorflow_AclGrphParseUtil_ParseAclInputFp16Nodes_test)
@@ -3621,6 +3673,24 @@ TEST_F(STestTensorflowParser, create_weights_parser_failed)
ModelParserFactory *modelFactory = ModelParserFactory::Instance();
shared_ptr<ModelParser> model_parser = modelFactory->CreateModelParser(FRAMEWORK_RESERVED);
ASSERT_TRUE(NULL == model_parser);

std::shared_ptr<OpParserFactory> parserFactory = OpParserFactory::Instance(domi::FrameworkType::CAFFE);
std::shared_ptr<OpParser> fusion_op_parser = parserFactory->CreateFusionOpParser(ge::parser::DATA);
ASSERT_TRUE(NULL == fusion_op_parser);

std::shared_ptr<OpParser> op_parser = parserFactory->CreateOpParser("10");
ASSERT_TRUE(NULL == op_parser);
}

TEST_F(STestTensorflowParser, custom_parser_adapter_register)
{
using PARSER_CREATOR_FN = std::function<std::shared_ptr<OpParser>(void)>;
PARSER_CREATOR_FN func = CustomParserAdapterRegistry::Instance()->GetCreateFunc(domi::TENSORFLOW);
CustomParserAdapterRegistry::Instance()->Register(domi::TENSORFLOW, func);
CustomParserAdapterRegistry::Instance()->Register(domi::TENSORFLOW, func);

func = CustomParserAdapterRegistry::Instance()->GetCreateFunc(domi::FRAMEWORK_RESERVED);
ASSERT_EQ(nullptr, func);
}

TEST_F(STestTensorflowParser, tensorflow_parser_api_test)
@@ -3650,6 +3720,53 @@ TEST_F(STestTensorflowParser, tensorflow_FP16_parser_test)
fp16.ToInt32();
fp16.ToUInt32();
fp16.IsInf();
fp16.operator+(fp16);
fp16.operator-(fp16);
fp16.operator*(fp16);
fp16.operator/(fp16);
fp16.operator+=(fp16);
fp16.operator-=(fp16);
fp16.operator*=(fp16);
fp16.operator/=(fp16);
fp16.operator==(fp16);
fp16.operator!=(fp16);
fp16.operator>(fp16);
fp16.operator>=(fp16);
fp16.operator<(fp16);
fp16.operator<=(fp16);
fp16.operator=(fp16);

float f_val = 0.1;
fp16.operator=(f_val);

double d_val = 0.2;
fp16.operator=(d_val);

int8_t i_val = 1;
fp16.operator=(i_val);

uint8_t ui_val = 2;
fp16.operator=(ui_val);

int16_t i_vals = 1;
fp16.operator=(i_vals);

uint16_t ui16_val = 1;
fp16.operator=(ui16_val);
ui16_val = 0;
fp16.operator=(ui16_val);
ui16_val = 100000;
fp16.operator=(ui16_val);

int32_t i32_val = 0;
fp16.operator=(i32_val);
i32_val = 1;
fp16.operator=(i32_val);

uint32_t ui32_val = 0;
fp16.operator=(ui32_val);
ui32_val = 1;
fp16.operator=(ui32_val);
}

} // namespace ge

Loading…
Cancel
Save