diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc index 373344f..9e2f875 100644 --- a/parser/tensorflow/graph_optimizer.cc +++ b/parser/tensorflow/graph_optimizer.cc @@ -88,6 +88,7 @@ REGISTER_OPTYPE_DEFINE(TF_BATCH_MATMUL, "BatchMatmul"); namespace ge { namespace { const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; +const char *const kShapeNodeName = "Shape"; } // namespace FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map g_OpSupportTranInfo = {}; @@ -1313,6 +1314,18 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map temp_node_cluser.push_back(src_node); } temp_node_cluser.push_back(node); + for (auto out_anchor : node->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(out_anchor); + for (auto in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_CHECK_NOTNULL(in_anchor); + NodePtr dst_node = in_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(dst_node); + GE_CHECK_NOTNULL(dst_node->GetOpDesc()); + if (dst_node->GetOpDesc()->GetType() == kShapeNodeName) { + temp_node_cluser.emplace_back(dst_node); + } + } + } if (temp_node_cluser.size() > 1) { vector node_cluser; node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end()); diff --git a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc index e9fe078..47ab56b 100644 --- a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc +++ b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc @@ -69,6 +69,30 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge op_dest->GetType().c_str(), dynamic_tensor_num); } + // add nodedef for shape insert by adapter when online_infer_dynamic + if (op_dest->GetType() == SHAPE) { + std::shared_ptr pkg_node = ge::parser::MakeShared(); + GE_CHECK_NOTNULL(pkg_node); + pkg_node->CopyFrom(*node); + + // Get the property opdef, if the property does not exist, return failure + pkg_node->mutable_attr()->erase(ge::ATTR_NAME_FRAMEWORK_OP_DEF); + pkg_node->mutable_attr()->erase(ge::ATTR_NAME_OUTPUT_TENSOR_DESC); + pkg_node->mutable_attr()->erase(ge::ATTR_NAME_INPUT_TENSOR_DESC); + pkg_node->mutable_attr()->erase(ge::VAR_ATTR_NAME); + + // Serialize nodedef into string and package as a whole + string serialized_node; + GE_IF_BOOL_EXEC(!pkg_node->SerializeToString(&serialized_node), + GELOGE(PARAM_INVALID, "In FrameworkOp trans NodeDef to string failed."); + return PARAM_INVALID); + + (void)AttrUtils::SetZeroCopyBytes( + op_dest, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, + Buffer::CopyFrom(reinterpret_cast(serialized_node.data()), serialized_node.length())); + GELOGI("node_def of %s is %s.", op_dest->GetName().c_str(), serialized_node.c_str()); + } + return SUCCESS; }