Browse Source

!579 sync parser to master 20220625

Merge pull request !579 from zhangfan/ge_dev
pull/585/MERGE
zhangfan Gitee 3 years ago
parent
commit
aa77814e5f
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 40 additions and 22 deletions
  1. +9
    -0
      parser/tensorflow/tensorflow_parser.cc
  2. +11
    -4
      parser/tensorflow/tensorflow_reshape_parser.cc
  3. +1
    -1
      parser/tensorflow/tensorflow_reshape_parser.h
  4. +18
    -16
      parser/tensorflow/tensorflow_squeeze_parser.cc
  5. +1
    -1
      parser/tensorflow/tensorflow_squeeze_parser.h

+ 9
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -206,6 +206,14 @@ void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr paren
}
GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str());
}
void AddDumpOriginNameForRootGraph(const ge::ComputeGraphPtr& graph) {
for (auto &node : graph->GetDirectNode()) {
if (ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, {node->GetName()})) {
GELOGD("Add dump origin name %s for node %s.", node->GetName().c_str(),
node->GetName().c_str());
}
}
}
} // namespace ge

namespace ge {
@@ -273,6 +281,7 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque

Status PostOpProcessForSubgraph(const ParseArg &arg) {
if (arg.parent_node == nullptr) {
AddDumpOriginNameForRootGraph(arg.graph);
return SUCCESS;
}
std::string op_type = arg.parent_node->GetType();


+ 11
- 4
parser/tensorflow/tensorflow_reshape_parser.cc View File

@@ -21,7 +21,9 @@
#include "parser/common/util.h"
#include "parser/tensorflow/tensorflow_util.h"
#include "parser/common/acl_graph_parser_util.h"
#include "parser/common/parser_utils.h"
#include "omg/parser/parser_inner_ctx.h"
#include "register/register_utils.h"

using domi::TENSORFLOW;
using namespace ge::parser;
@@ -57,9 +59,14 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att
return SUCCESS;
}

Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op);
GE_CHECK_NOTNULL(op_dest);

ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op),
"call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str());
op.BreakConnect();

const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node_src);
@@ -82,10 +89,10 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr
"parse output desc failed");
}

GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED,
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED,
"set input desc failed");

GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED,
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED,
"set output desc failed"););

return SUCCESS;


+ 1
- 1
parser/tensorflow/tensorflow_reshape_parser.h View File

@@ -34,7 +34,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowReshapeParser : public TensorFlowOpParser
* @return FAILED parse failed
* @author
*/
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override;
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;
};
} // namespace ge



+ 18
- 16
parser/tensorflow/tensorflow_squeeze_parser.cc View File

@@ -23,6 +23,8 @@
#include "graph/utils/type_utils.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/acl_graph_parser_util.h"
#include "parser/common/parser_utils.h"
#include "register/register_utils.h"

using domi::tensorflow::AttrValue;
using std::vector;
@@ -62,24 +64,24 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att
return SUCCESS;
}

Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op);
GE_CHECK_NOTNULL(op_dest);

ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op),
"call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str());
op.BreakConnect();

const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
bool has_axis = true;
bool has_dims = true;

domi::tensorflow::AttrValue axis;
domi::tensorflow::AttrValue dims;
if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis)) {
has_axis = false;
}
if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims)) {
has_dims = false;
}

bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis);
bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims);
if (!has_axis && !has_dims) {
return SUCCESS;
}
@@ -103,9 +105,9 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr
int32_t result = values.i(i);
v_result.push_back(result);
}
if (!ge::AttrUtils::SetListInt(op, SQUEEZE_ATTR_AXIS, v_result)) {
if (!ge::AttrUtils::SetListInt(op_dest, SQUEEZE_ATTR_AXIS, v_result)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", SQUEEZE_ATTR_AXIS.c_str(),
op->GetName().c_str(), op->GetType().c_str());
op_dest->GetName().c_str(), op_dest->GetType().c_str());
GELOGE(FAILED, "Set squeeze axis attr failed");
return FAILED;
}
@@ -125,14 +127,14 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr
"parse output desc failed");
}

if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) {
if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_INPUT_DESC.c_str(),
op->GetName().c_str(), op->GetType().c_str());
op_dest->GetName().c_str(), op_dest->GetType().c_str());
GELOGE(FAILED, "Set input desc failed");
return FAILED;
} if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) {
} if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_OUTPUT_DESC.c_str(),
op->GetName().c_str(), op->GetType().c_str());
op_dest->GetName().c_str(), op_dest->GetType().c_str());
GELOGE(FAILED, "Set output desc failed");
return FAILED;
})


+ 1
- 1
parser/tensorflow/tensorflow_squeeze_parser.h View File

@@ -22,7 +22,7 @@
namespace ge {
class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser {
public:
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override;
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;

private:
static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc);


Loading…
Cancel
Save