Browse Source

online_infer support for dynamic_dims

pull/77/head
zhou_lili 5 years ago
parent
commit
09028af074
2 changed files with 37 additions and 0 deletions
  1. +13
    -0
      parser/tensorflow/graph_optimizer.cc
  2. +24
    -0
      parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc

+ 13
- 0
parser/tensorflow/graph_optimizer.cc View File

@@ -88,6 +88,7 @@ REGISTER_OPTYPE_DEFINE(TF_BATCH_MATMUL, "BatchMatmul");
namespace ge {
namespace {
const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal";
const char *const kShapeNodeName = "Shape";
} // namespace

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<string, OpSupportTranInfo> g_OpSupportTranInfo = {};
@@ -1313,6 +1314,18 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>
temp_node_cluser.push_back(src_node);
}
temp_node_cluser.push_back(node);
for (auto out_anchor : node->GetAllOutDataAnchors()) {
GE_CHECK_NOTNULL(out_anchor);
for (auto in_anchor : out_anchor->GetPeerInDataAnchors()) {
GE_CHECK_NOTNULL(in_anchor);
NodePtr dst_node = in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);
GE_CHECK_NOTNULL(dst_node->GetOpDesc());
if (dst_node->GetOpDesc()->GetType() == kShapeNodeName) {
temp_node_cluser.emplace_back(dst_node);
}
}
}
if (temp_node_cluser.size() > 1) {
vector<NodePtr> node_cluser;
node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end());


+ 24
- 0
parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc View File

@@ -69,6 +69,30 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge
op_dest->GetType().c_str(), dynamic_tensor_num);
}

// add nodedef for shape insert by adapter when online_infer_dynamic
if (op_dest->GetType() == SHAPE) {
std::shared_ptr<NodeDef> pkg_node = ge::parser::MakeShared<NodeDef>();
GE_CHECK_NOTNULL(pkg_node);
pkg_node->CopyFrom(*node);

// Get the property opdef, if the property does not exist, return failure
pkg_node->mutable_attr()->erase(ge::ATTR_NAME_FRAMEWORK_OP_DEF);
pkg_node->mutable_attr()->erase(ge::ATTR_NAME_OUTPUT_TENSOR_DESC);
pkg_node->mutable_attr()->erase(ge::ATTR_NAME_INPUT_TENSOR_DESC);
pkg_node->mutable_attr()->erase(ge::VAR_ATTR_NAME);

// Serialize nodedef into string and package as a whole
string serialized_node;
GE_IF_BOOL_EXEC(!pkg_node->SerializeToString(&serialized_node),
GELOGE(PARAM_INVALID, "In FrameworkOp trans NodeDef to string failed.");
return PARAM_INVALID);

(void)AttrUtils::SetZeroCopyBytes(
op_dest, ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(serialized_node.data()), serialized_node.length()));
GELOGI("node_def of %s is %s.", op_dest->GetName().c_str(), serialized_node.c_str());
}

return SUCCESS;
}



Loading…
Cancel
Save