Browse Source

batch for bert train

pull/440/head
wangzhengjun 4 years ago
parent
commit
22209d49fd
5 changed files with 217 additions and 9 deletions
  1. +16
    -8
      parser/tensorflow/graph_optimizer.cc
  2. +3
    -1
      parser/tensorflow/graph_optimizer.h
  3. +7
    -0
      tests/depends/ops_stub/ops_stub.h
  4. +174
    -0
      tests/st/testcase/origin_models/test_getnext_dynamic_fusion.pbtxt
  5. +17
    -0
      tests/st/testcase/test_tensorflow_parser.cc

+ 16
- 8
parser/tensorflow/graph_optimizer.cc View File

@@ -35,7 +35,8 @@ REGISTER_OPTYPE_DEFINE(TF_BATCH_MATMUL, "BatchMatmul");
namespace ge {
namespace {
const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal";
const char *const kShapeNodeName = "Shape";
const char *const kShapeNodeType = "Shape";
const char *const kShapeNodeNamePrefix = "getnext_shape_";
} // namespace

Status ParserGraphOptimizer::FusionFmkop() {
@@ -62,19 +63,25 @@ Status ParserGraphOptimizer::FusionFmkop() {
return SUCCESS;
}

Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluser_Map) {
Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluster_map) {
GE_CHECK_NOTNULL(graph_);
bool hasGetNext = false;
bool has_get_next = false;
for (auto node : graph_->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue);
string type = "";
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
if (type == "IteratorGetNext") {
hasGetNext = true;
has_get_next = true;
break;
}
}
return GetFusionCluster(has_get_next, node_cluster_map);
}

Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next,
unordered_map<string, vector<NodePtr>> &node_cluster_map) {
GE_CHECK_NOTNULL(graph_);
for (auto node : graph_->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue)
@@ -97,7 +104,8 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>
NodePtr dst_node = in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);
GE_CHECK_NOTNULL(dst_node->GetOpDesc());
if (dst_node->GetOpDesc()->GetType() == kShapeNodeName) {
if ((dst_node->GetName().find(kShapeNodeNamePrefix) != std::string::npos) &&
(dst_node->GetOpDesc()->GetType() == kShapeNodeType)) {
temp_node_cluser.emplace_back(dst_node);
}
}
@@ -105,14 +113,14 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>
if (temp_node_cluser.size() > 1) {
vector<NodePtr> node_cluser;
node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end());
node_cluser_Map[temp_node_cluser[0]->GetName()] = node_cluser;
node_cluster_map[temp_node_cluser[0]->GetName()] = node_cluser;
}
temp_node_cluser.clear();
GELOGI("MarkForFusion, IteratorGetNext graph mark success.");
}

if (!hasGetNext && (type == "Iterator" || type == "IteratorV2")) {
GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluser_Map), "find framework node to be fused fail.");
if (!has_get_next && (type == "Iterator" || type == "IteratorV2")) {
GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail.");
GELOGI("MarkForFusion, Iterator init graph mark success.");
}
}


+ 3
- 1
parser/tensorflow/graph_optimizer.h View File

@@ -41,7 +41,9 @@ class ParserGraphOptimizer {

domi::Status FindFmkNodeCluser(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map) const;

domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map);
domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluster_map);

domi::Status GetFusionCluster(const bool has_get_next, unordered_map<string, vector<NodePtr>> &node_cluster_map);

domi::Status UpdateGraph(std::vector<ge::NodePtr> &nodes);



+ 7
- 0
tests/depends/ops_stub/ops_stub.h View File

@@ -316,6 +316,13 @@ REG_OP(Softmax)
.ATTR(beta, Float, 0)
.OP_END_FACTORY_REG(Softmax)

REG_OP(Shape)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
.ATTR(dtype, Int, DT_INT32)
.OP_END_FACTORY_REG(Shape)

// for plugin
static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
return SUCCESS;


+ 174
- 0
tests/st/testcase/origin_models/test_getnext_dynamic_fusion.pbtxt View File

@@ -0,0 +1,174 @@
node {
name: "IteratorV2"
op: "IteratorV2"
attr {
key: "op_def"
value {
s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001"
}
}
attr {
key: "output_types"
value {
list {
type: DT_INT64
}
}
}
attr {
key: "output_tensor_desc"
value {
list {
func {
name: "0"
attr {
key: "serialize_datatype"
value: {
i: 9
}
}
attr {
key: "serialize_format"
value: {
i: 1
}
}
attr {
key: "serialize_shape"
value {
type: DT_INT32
}
}
}
}
}
}
}
node {
name: "IteratorGetNext"
op: "IteratorGetNext"
input: "IteratorV2"
attr {
key: "output_types"
value {
list {
type: DT_INT64
}
}
}
attr {
key: "op_def"
value {
s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001"
}
}
attr {
key: "input_tensor_desc"
value {
list {
func {
name: "0"
attr {
key: "serialize_datatype"
value: {
i: 9
}
}
attr {
key: "serialize_format"
value: {
i: 1
}
}
attr {
key: "serialize_shape"
value {
type: DT_INT32
}
}
}
}
}
}
attr {
key: "output_tensor_desc"
value {
list {
func {
name: "0"
attr {
key: "serialize_datatype"
value: {
i: 9
}
}
attr {
key: "serialize_format"
value: {
i: 1
}
}
attr {
key: "serialize_shape"
value {
list {
i: -1
i: -1
}
}
}
}
}
}
}
}
node {
name: "getnext_shape_0"
op: "Shape"
input: "IteratorGetNext"
attr {
key: "op_def"
value {
s: "\n\005Shape\022\n\n\005input\"\001T\032\022\n\006output\"\010out_type\"\t\n\001T\022\004type\"\034\n\010out_type\022\004type\032\0020\003:\006\n\0042\002\003\t"
}
}
}
node {
name: "retval_GetNext_0_0"
op: "_Retval"
input: "IteratorGetNext"
attr {
key: "index"
value {
i: 0
}
}
attr {
key: "op_def"
value {
s: ""
}
}
}
node {
name: "retval_GetNext_0_1"
op: "_Retval"
input: "getnext_shape_0"
attr {
key: "index"
value {
i: 1
}
}
attr {
key: "op_def"
value {
s: ""
}
}
}
library {
}
versions {
producer: 134
}

+ 17
- 0
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -4225,4 +4225,21 @@ TEST_F(STestTensorflowParser, parser_UppdateInputMap_test)
delete graph;
}

TEST_F(STestTensorflowParser, tensorflow_optimizer_fmk_fusion_op) {
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");
caseDir = caseDir.substr(0, idx);
const std::string root_proto = caseDir + "/origin_models/test_getnext_dynamic_fusion.pbtxt";
domi::tensorflow::GraphDef graphDef;

bool protoRet = parser::ReadProtoFromText(root_proto.c_str(), &graphDef);
ASSERT_EQ(protoRet, true);

TensorFlowModelParser tensorflow_parser;
ge::ComputeGraphPtr root_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph");
Status ret = tensorflow_parser.ParseProto(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(root_graph->GetDirectNode().size(), 3);
}

} // namespace ge

Loading…
Cancel
Save