Merge pull request !579 from zhangfan/ge_devpull/585/MERGE
| @@ -206,6 +206,14 @@ void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr paren | |||
| } | |||
| GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str()); | |||
| } | |||
| void AddDumpOriginNameForRootGraph(const ge::ComputeGraphPtr& graph) { | |||
| for (auto &node : graph->GetDirectNode()) { | |||
| if (ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, {node->GetName()})) { | |||
| GELOGD("Add dump origin name %s for node %s.", node->GetName().c_str(), | |||
| node->GetName().c_str()); | |||
| } | |||
| } | |||
| } | |||
| } // namespace ge | |||
| namespace ge { | |||
| @@ -273,6 +281,7 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque | |||
| Status PostOpProcessForSubgraph(const ParseArg &arg) { | |||
| if (arg.parent_node == nullptr) { | |||
| AddDumpOriginNameForRootGraph(arg.graph); | |||
| return SUCCESS; | |||
| } | |||
| std::string op_type = arg.parent_node->GetType(); | |||
| @@ -21,7 +21,9 @@ | |||
| #include "parser/common/util.h" | |||
| #include "parser/tensorflow/tensorflow_util.h" | |||
| #include "parser/common/acl_graph_parser_util.h" | |||
| #include "parser/common/parser_utils.h" | |||
| #include "omg/parser/parser_inner_ctx.h" | |||
| #include "register/register_utils.h" | |||
| using domi::TENSORFLOW; | |||
| using namespace ge::parser; | |||
| @@ -57,9 +59,14 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att | |||
| return SUCCESS; | |||
| } | |||
| Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | |||
| Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | |||
| GE_CHECK_NOTNULL(op_src); | |||
| GE_CHECK_NOTNULL(op); | |||
| GE_CHECK_NOTNULL(op_dest); | |||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); | |||
| GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op), | |||
| "call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str()); | |||
| op.BreakConnect(); | |||
| const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||
| GE_CHECK_NOTNULL(node_src); | |||
| @@ -82,10 +89,10 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||
| "parse output desc failed"); | |||
| } | |||
| GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED, | |||
| GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED, | |||
| "set input desc failed"); | |||
| GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED, | |||
| GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED, | |||
| "set output desc failed");); | |||
| return SUCCESS; | |||
| @@ -34,7 +34,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowReshapeParser : public TensorFlowOpParser | |||
| * @return FAILED parse failed | |||
| * @author | |||
| */ | |||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; | |||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | |||
| }; | |||
| } // namespace ge | |||
| @@ -23,6 +23,8 @@ | |||
| #include "graph/utils/type_utils.h" | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "parser/common/acl_graph_parser_util.h" | |||
| #include "parser/common/parser_utils.h" | |||
| #include "register/register_utils.h" | |||
| using domi::tensorflow::AttrValue; | |||
| using std::vector; | |||
| @@ -62,24 +64,24 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att | |||
| return SUCCESS; | |||
| } | |||
| Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | |||
| Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | |||
| GE_CHECK_NOTNULL(op_src); | |||
| GE_CHECK_NOTNULL(op); | |||
| GE_CHECK_NOTNULL(op_dest); | |||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); | |||
| GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op), | |||
| "call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str()); | |||
| op.BreakConnect(); | |||
| const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||
| GE_CHECK_NOTNULL(node); | |||
| GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | |||
| bool has_axis = true; | |||
| bool has_dims = true; | |||
| domi::tensorflow::AttrValue axis; | |||
| domi::tensorflow::AttrValue dims; | |||
| if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis)) { | |||
| has_axis = false; | |||
| } | |||
| if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims)) { | |||
| has_dims = false; | |||
| } | |||
| bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis); | |||
| bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims); | |||
| if (!has_axis && !has_dims) { | |||
| return SUCCESS; | |||
| } | |||
| @@ -103,9 +105,9 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||
| int32_t result = values.i(i); | |||
| v_result.push_back(result); | |||
| } | |||
| if (!ge::AttrUtils::SetListInt(op, SQUEEZE_ATTR_AXIS, v_result)) { | |||
| if (!ge::AttrUtils::SetListInt(op_dest, SQUEEZE_ATTR_AXIS, v_result)) { | |||
| REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", SQUEEZE_ATTR_AXIS.c_str(), | |||
| op->GetName().c_str(), op->GetType().c_str()); | |||
| op_dest->GetName().c_str(), op_dest->GetType().c_str()); | |||
| GELOGE(FAILED, "Set squeeze axis attr failed"); | |||
| return FAILED; | |||
| } | |||
| @@ -125,14 +127,14 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||
| "parse output desc failed"); | |||
| } | |||
| if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) { | |||
| if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) { | |||
| REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_INPUT_DESC.c_str(), | |||
| op->GetName().c_str(), op->GetType().c_str()); | |||
| op_dest->GetName().c_str(), op_dest->GetType().c_str()); | |||
| GELOGE(FAILED, "Set input desc failed"); | |||
| return FAILED; | |||
| } if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) { | |||
| } if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) { | |||
| REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_OUTPUT_DESC.c_str(), | |||
| op->GetName().c_str(), op->GetType().c_str()); | |||
| op_dest->GetName().c_str(), op_dest->GetType().c_str()); | |||
| GELOGE(FAILED, "Set output desc failed"); | |||
| return FAILED; | |||
| }) | |||
| @@ -22,7 +22,7 @@ | |||
| namespace ge { | |||
| class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser { | |||
| public: | |||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; | |||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | |||
| private: | |||
| static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); | |||