Browse Source

increase parser st

pull/403/head
jwx930962 4 years ago
parent
commit
fc4aabf874
6 changed files with 83 additions and 2 deletions
  1. +1
    -1
      metadef
  2. +26
    -0
      tests/st/testcase/origin_models/conv2d_depthwise_pb_gen.py
  3. BIN
      tests/st/testcase/origin_models/model.pb
  4. +27
    -0
      tests/st/testcase/origin_models/test_conv2d_pb_gen.py
  5. BIN
      tests/st/testcase/origin_models/test_depth_wise_conv2d.pb
  6. +29
    -1
      tests/st/testcase/test_tensorflow_parser.cc

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 326ecbb2b4837699aa674cc30e9b9956e4fd364d
Subproject commit 8ad7cbd3c18d322381583d75c32906f8374a348d

+ 26
- 0
tests/st/testcase/origin_models/conv2d_depthwise_pb_gen.py View File

@@ -0,0 +1,26 @@
import tensorflow as tf
import os

pb_file_path = os.getcwd()

with tf.compat.v1.Session(graph=tf.Graph()) as sess:
# NHWC
fmap_shape = [17, 101, 101, 17]
filter_size = [5, 5, 17, 1]
dy_shape = [17, 49, 49, 17]
strideh, stridew = [2, 2]
padding = 'VALID'
tensor_x1 = tf.compat.v1.placeholder(dtype="float16", shape=fmap_shape)
tensor_x2 = tf.compat.v1.placeholder(dtype="float16", shape=fmap_shape)
tensor_x = tf.add(tensor_x1, tensor_x2)
tensor_dy1 = tf.compat.v1.placeholder(dtype="float16", shape=dy_shape)
tensor_dy2 = tf.compat.v1.placeholder(dtype="float16", shape=dy_shape)
tensor_dy = tf.add(tensor_dy1, tensor_dy2)
op = tf.nn.depthwise_conv2d_backprop_filter(tensor_x, filter_size, tensor_dy,
strides=[1, strideh, stridew, 1],
padding=padding,
data_format='NHWC',
dilations=[1,1,1,1])

tf.io.write_graph(sess.graph, logdir="./", name="test_depth_wise_conv2d.pb", as_text=False)

BIN
tests/st/testcase/origin_models/model.pb View File


+ 27
- 0
tests/st/testcase/origin_models/test_conv2d_pb_gen.py View File

@@ -0,0 +1,27 @@
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

def generate_conv2d_pb():
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
input_x = tf.compat.v1.placeholder(dtype="float32", shape=(1,56,56,64))
input_filter = tf.compat.v1.placeholder(dtype="float32", shape=(3,3,64,64))
op = tf.nn.conv2d(input_x, input_filter, strides=[1,1,1,1], padding=[[0,0],[1,1],[1,1],[0,0]],
data_format="NHWC", dilations=[1,1,1,1], name='conv2d_res')
tf.io.write_graph(sess.graph, logdir="./", name="conv2d.pb", as_text=False)


def generate_add_pb():
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
x = tf.compat.v1.placeholder(tf.int32, name='x')
y = tf.compat.v1.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
op = tf.add(xy, b, name='op_to_store')
tf.io.write_graph(sess.graph, logdir="./", name="model.pb", as_text=False)

if __name__=='__main__':
generate_conv2d_pb()
generate_add_pb()

BIN
tests/st/testcase/origin_models/test_depth_wise_conv2d.pb View File


+ 29
- 1
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -22,7 +22,7 @@
#include "parser/common/register_tbe.h"
#include "external/parser/tensorflow_parser.h"
#include "st/parser_st_utils.h"
#include "tests/depends/ops_stub/ops_stub.h"
#include "parser/common/acl_graph_parser_util.h"

namespace ge {
class STestTensorflowParser : public testing::Test {
@@ -76,3 +76,31 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) {
EXPECT_EQ(net_out_name.at(0), "add_test_1:0");
}
} // namespace ge

TEST_F(STestTensorflowParser, tensorflow_model_Failed) {
ge::Graph graph;
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");
caseDir = caseDir.substr(0, idx);

std::string modelFile = caseDir + "/origin_models/model.pb";
auto status = ge::aclgrphParseTensorFlow(modelFile.c_str(), graph);
EXPECT_EQ(status, ge::GRAPH_FAILED);

modelFile = caseDir + "/origin_models/test_depth_wise_conv2d.pb";
status = ge::aclgrphParseTensorFlow(modelFile.c_str(), graph);
EXPECT_EQ(status, ge::GRAPH_FAILED);
}

TEST_F(STestTensorflowParser, tensorflow_model_not_exist) {
ge::Graph graph;
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");
caseDir = caseDir.substr(0, idx);

// model file is not exist
std::string modelFile = caseDir + "/origin_models/conv2d_explicit1_pad.pb";
auto status = ge::aclgrphParseTensorFlow(modelFile.c_str(), graph);
EXPECT_EQ(status, ge::GRAPH_FAILED);
}
} // namespace ge

Loading…
Cancel
Save