| @@ -1 +1 @@ | |||||
| Subproject commit f62cba4fdf845ffe04e5c1e37ea990d22c438910 | |||||
| Subproject commit 51f76677af9299a919440416af70471f191380b8 | |||||
| @@ -21,6 +21,7 @@ | |||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "parser/common/parser_utils.h" | #include "parser/common/parser_utils.h" | ||||
| #include "graph/def_types.h" | |||||
| using domi::ONNX; | using domi::ONNX; | ||||
| using domi::ParseParamByOpFunc; | using domi::ParseParamByOpFunc; | ||||
| @@ -29,7 +30,7 @@ using domi::ParseParamFunc; | |||||
| namespace ge { | namespace ge { | ||||
| Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) { | Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) { | ||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| const ge::onnx::NodeProto *node_src = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||||
| const ge::onnx::NodeProto *node_src = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | |||||
| GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
| GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str()); | GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str()); | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/def_types.h" | |||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | #include "framework/omg/parser/parser_inner_ctx.h" | ||||
| #include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
| @@ -29,7 +30,7 @@ using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | ||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| const ge::onnx::NodeProto *node_src = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||||
| const ge::onnx::NodeProto *node_src = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | |||||
| GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
| GELOGD("Onnx op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op_type().c_str()); | GELOGD("Onnx op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op_type().c_str()); | ||||
| if (ParseInputFromModel(op_src, op_def) != SUCCESS) { | if (ParseInputFromModel(op_src, op_def) != SUCCESS) { | ||||
| @@ -73,7 +74,7 @@ int64_t OnnxDataParser::ParseInputTensor(const ge::onnx::AttributeProto &attribu | |||||
| Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator &op_def) { | Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator &op_def) { | ||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| const ge::onnx::NodeProto *node = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||||
| const ge::onnx::NodeProto *node = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | |||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| // Get attr t:'input_tensor' form NodeProto | // Get attr t:'input_tensor' form NodeProto | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/def_types.h" | |||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "register/register.h" | #include "register/register.h" | ||||
| @@ -43,7 +44,7 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge | |||||
| GELOGE(PARAM_INVALID, "Op src is null"); | GELOGE(PARAM_INVALID, "Op src is null"); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
| const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||||
| GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | ||||
| if (op_dest == nullptr) { | if (op_dest == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid"); | ||||
| @@ -31,7 +31,7 @@ Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr & | |||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| const std::string name = op_desc->GetName(); | const std::string name = op_desc->GetName(); | ||||
| const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
| const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||||
| domi::tensorflow::AttrValue str_attr; | domi::tensorflow::AttrValue str_attr; | ||||
| if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) { | if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) { | ||||
| REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid", | REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid", | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
| #include "graph/def_types.h" | |||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using ge::parser::MERGE; | using ge::parser::MERGE; | ||||
| @@ -30,7 +31,7 @@ Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr & | |||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
| const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||||
| domi::tensorflow::AttrValue attr_num; | domi::tensorflow::AttrValue attr_num; | ||||
| if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) { | if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) { | ||||
| GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str()); | GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str()); | ||||
| @@ -42,10 +42,9 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att | |||||
| ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); | ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| return PARAM_INVALID); | return PARAM_INVALID); | ||||
| // calculate size | // calculate size | ||||
| int64_t tmp_dim = 0; | |||||
| int64_t real_size = 1; | int64_t real_size = 1; | ||||
| for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | ||||
| tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
| int64_t tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
| GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | ||||
| real_size *= tmp_dim; | real_size *= tmp_dim; | ||||
| } | } | ||||
| @@ -47,9 +47,8 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att | |||||
| return domi::PARAM_INVALID); | return domi::PARAM_INVALID); | ||||
| // calculate size | // calculate size | ||||
| int64_t real_size = 1; | int64_t real_size = 1; | ||||
| int64_t tmp_dim = 0; | |||||
| for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | ||||
| tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
| int64_t tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
| GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | ||||
| PARSER_INT64_MULCHECK(real_size, tmp_dim); | PARSER_INT64_MULCHECK(real_size, tmp_dim); | ||||
| real_size *= tmp_dim; | real_size *= tmp_dim; | ||||
| @@ -271,9 +271,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr | |||||
| GE_CHK_BOOL_RET_STATUS(ge::TypeUtils::GetDataTypeLength(data_type, size_type), PARAM_INVALID, | GE_CHK_BOOL_RET_STATUS(ge::TypeUtils::GetDataTypeLength(data_type, size_type), PARAM_INVALID, | ||||
| "dataType no define size , parse ge_desc failed."); | "dataType no define size , parse ge_desc failed."); | ||||
| // get size | // get size | ||||
| int64_t tmp_dim = 0; | |||||
| for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | ||||
| tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
| int64_t tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
| // The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate. | // The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate. | ||||
| // Here, special treatment is given to the two operators. | // Here, special treatment is given to the two operators. | ||||